forked from 626_privacy/tensorflow_privacy
713 lines
21 KiB
Python
713 lines
21 KiB
Python
|
# Copyright 2021 Google LLC
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
# This program solves the NeuraCrypt challenge to 100% accuracy.
|
||
|
# Given a set of encoded images and original versions of those,
|
||
|
# it shows how to match the original to the encoded.
|
||
|
|
||
|
import collections
|
||
|
import hashlib
|
||
|
import time
|
||
|
import multiprocessing as mp
|
||
|
|
||
|
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
import torch.nn as nn
|
||
|
import scipy.stats
|
||
|
import matplotlib.pyplot as plt
|
||
|
from PIL import Image
|
||
|
|
||
|
import jax
|
||
|
import jax.numpy as jn
|
||
|
import objax
|
||
|
import scipy.optimize
|
||
|
import numpy as np
|
||
|
import multiprocessing as mp
|
||
|
|
||
|
# Objax neural network that's going to embed patches to a
|
||
|
# low dimensional space to guess if two patches correspond
|
||
|
# to the same original image.
|
||
|
class Model(objax.Module):
    """Two small MLP towers that embed patch features onto the unit sphere.

    ``encode`` is applied to features of original images and ``decode`` to
    features of encoded images. Both map a 15-dimensional feature vector to an
    L2-normalised 8-dimensional embedding, so matching patches can be compared
    with a dot product. ``scale`` is a learnable 1x1 weight without bias.
    """

    def __init__(self):
        feature_dim, width = 15, 64
        # Same architecture for both towers, but independent weights.
        self.encoder = objax.nn.Sequential([
            objax.nn.Linear(feature_dim, width),
            objax.functional.leaky_relu,
            objax.nn.Linear(width, width),
            objax.functional.leaky_relu,
            objax.nn.Linear(width, 8),
        ])
        self.decoder = objax.nn.Sequential([
            objax.nn.Linear(feature_dim, width),
            objax.functional.leaky_relu,
            objax.nn.Linear(width, width),
            objax.functional.leaky_relu,
            objax.nn.Linear(width, 8),
        ])
        self.scale = objax.nn.Linear(1, 1, use_bias=False)

    @staticmethod
    def _unit_rows(v):
        """Divide ``v`` by its Euclidean norm taken along the last axis."""
        return v / jn.sum(v ** 2, axis=-1, keepdims=True) ** .5

    def encode(self, x):
        """Embed features computed from original images."""
        return self._unit_rows(self.encoder(x))

    def decode(self, x):
        """Embed features computed from encoded images."""
        return self._unit_rows(self.decoder(x))
|
||
|
|
||
|
# Proxy dataset for analysis
|
||
|
class ImageNet:
    """PrivateEncoder configuration for the ImageNet proxy dataset.

    Three input channels, 256x256 images, 16x16 patches, 2048 hidden units and
    a private depth of 7.
    """

    img_size = (256, 256)
    num_chan = 3
    private_kernel_size = 16
    private_depth = 7
    hidden_dim = 2048

    def __init__(self, remove):
        # A true value makes PrivateEncoder.forward skip its per-image patch shuffle.
        self.remove_pixel_shuffle = remove
|
||
|
|
||
|
# Original dataset as used in the NeuraCrypt paper
|
||
|
class Xray:
    """PrivateEncoder configuration for the X-ray data used in the NeuraCrypt paper.

    Single-channel 256x256 images, 16x16 patches, 2048 hidden units and a
    private depth of 4.
    """

    img_size = (256, 256)
    num_chan = 1
    private_kernel_size = 16
    private_depth = 4
    hidden_dim = 2048

    def __init__(self, remove):
        # A true value makes PrivateEncoder.forward skip its per-image patch shuffle.
        self.remove_pixel_shuffle = remove
|
||
|
|
||
|
## The following class is taken directly from the NeuraCrypt codebase.
|
||
|
## https://github.com/yala/NeuraCrypt
|
||
|
## which is originally licensed under the MIT License
|
||
|
class PrivateEncoder(nn.Module):
    """Random patch encoder: patch conv, 1x1-conv stack, positional embedding,
    mixer, and an optional random shuffle of the patch order per image.

    ``args`` must provide ``num_chan``, ``private_kernel_size``, ``hidden_dim``,
    ``img_size``, ``private_depth`` and ``remove_pixel_shuffle``. The
    constructor also writes ``args.input_dim`` on the caller's object.
    """
    def __init__(self, args, width_factor=1):
        super(PrivateEncoder, self).__init__()
        self.args = args
        input_dim = args.num_chan
        patch_size = args.private_kernel_size
        output_dim = args.hidden_dim
        # Computed from img_size[0] only, i.e. assumes square images.
        num_patches = (args.img_size[0] // patch_size) **2
        self.noise_size = 1  # not read anywhere in this class

        # Side effect on the caller-supplied config object.
        args.input_dim = args.hidden_dim


        layers = [
                    # stride == kernel_size: each output position sees one
                    # non-overlapping patch_size x patch_size patch.
                    nn.Conv2d(input_dim, output_dim * width_factor, kernel_size=patch_size, dilation=1 ,stride=patch_size),
                    nn.ReLU()
                    ]
        for _ in range(self.args.private_depth):
            layers.extend( [
                nn.Conv2d(output_dim * width_factor, output_dim * width_factor , kernel_size=1, dilation=1, stride=1),
                # Without running stats, batch statistics are used in both
                # train and eval mode.
                nn.BatchNorm2d(output_dim * width_factor, track_running_stats=False),
                nn.ReLU()
            ])


        self.image_encoder = nn.Sequential(*layers)

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, output_dim * width_factor))

        self.mixer = nn.Sequential( *[
            nn.ReLU(),
            nn.Linear(output_dim * width_factor, output_dim)
            ])


    def forward(self, x):
        """Encode a batch of images into per-patch vectors.

        Args:
            x: image tensor laid out as Conv2d expects, (B, num_chan, H, W).

        Returns:
            Tensor of shape (B, patches, hidden_dim). Unless
            ``args.remove_pixel_shuffle`` is set, the patch order is randomly
            permuted independently for every image in the batch.
        """
        encoded = self.image_encoder(x)
        B, C, H,W = encoded.size()
        # (B, C, H, W) -> (B, H*W, C): one row per patch.
        encoded = encoded.view([B, -1, H*W]).transpose(1,2)
        # NOTE(review): in-place add on a view of the image_encoder output (which
        # ends with ReLU). If autograd reports an in-place modification error when
        # training this module, use `encoded = encoded + self.pos_embedding`.
        encoded += self.pos_embedding
        encoded = self.mixer(encoded)

        ## Shuffle indices
        if not self.args.remove_pixel_shuffle:
            shuffled = torch.zeros_like(encoded)
            for i in range(B):
                idx = torch.randperm(H*W, device=encoded.device)
                # Equivalent to shuffled[i] = encoded[i, idx]; kept as upstream wrote it.
                for j, k in enumerate(idx):
                    shuffled[i,j] = encoded[i,k]
            encoded = shuffled

        return encoded
|
||
|
## End copied code
|
||
|
|
||
|
def setup(ds):
|
||
|
"""
|
||
|
Load the datasets to use. Nothing interesting to see.
|
||
|
"""
|
||
|
global x_train, y_train
|
||
|
if ds == 'imagenet':
|
||
|
import torchvision
|
||
|
transform = torchvision.transforms.Compose([
|
||
|
torchvision.transforms.Resize(256),
|
||
|
torchvision.transforms.CenterCrop(256),
|
||
|
torchvision.transforms.ToTensor()])
|
||
|
imagenet_data = torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/',
|
||
|
split='val',
|
||
|
transform=transform)
|
||
|
data_loader = torch.utils.data.DataLoader(imagenet_data,
|
||
|
batch_size=100,
|
||
|
shuffle=True,
|
||
|
num_workers=8)
|
||
|
r = []
|
||
|
for x,_ in data_loader:
|
||
|
if len(r) > 1000: break
|
||
|
print(x.shape)
|
||
|
r.extend(x.numpy())
|
||
|
x_train = np.array(r)
|
||
|
print(x_train.shape)
|
||
|
elif ds == 'xray':
|
||
|
import torchvision
|
||
|
transform = torchvision.transforms.Compose([
|
||
|
torchvision.transforms.Resize(256),
|
||
|
torchvision.transforms.CenterCrop(256),
|
||
|
torchvision.transforms.ToTensor()])
|
||
|
imagenet_data = torchvision.datasets.ImageFolder('CheXpert-v1.0/train',
|
||
|
transform=transform)
|
||
|
data_loader = torch.utils.data.DataLoader(imagenet_data,
|
||
|
batch_size=100,
|
||
|
shuffle=True,
|
||
|
num_workers=8)
|
||
|
r = []
|
||
|
for x,_ in data_loader:
|
||
|
if len(r) > 1000: break
|
||
|
print(x.shape)
|
||
|
r.extend(x.numpy())
|
||
|
x_train = np.array(r)
|
||
|
print(x_train.shape)
|
||
|
elif ds == 'challenge':
|
||
|
x_train = np.load("orig-7.npy")
|
||
|
print(np.min(x_train), np.max(x_train), x_train.shape)
|
||
|
else:
|
||
|
raise
|
||
|
|
||
|
|
||
|
def gen_train_data():
    """
    Generate aligned training data for the patch similarity function.

    For 30 rounds, build a freshly seeded PrivateEncoder (X-ray config, patch
    shuffle disabled). Encode 100 images drawn with replacement from the
    module-level ``x_train``. Append the encodings to ``encoded_train`` and
    the inputs to ``original_train`` (both module-level lists, reset here).
    """
    global encoded_train, original_train

    encoded_train = []
    original_train = []

    args = Xray(True)

    C = 100
    # Seed once from the clock and offset per round. Re-seeding with
    # int(time.time()) every round gave identical seeds, and hence identical
    # "random" encoders, to rounds that finished within the same second.
    base_seed = int(time.time())
    for i in range(30):
        print(i)
        torch.manual_seed(base_seed + i)
        e = PrivateEncoder(args).cuda()
        batch = np.random.randint(0, len(x_train), size=C)
        xin = x_train[batch]

        r = []
        # Only the outputs are needed; skip building autograd graphs.
        with torch.no_grad():
            for start in range(0, C, 32):
                r.extend(e(torch.tensor(xin[start:start+32]).cuda()).cpu().numpy())
        r = np.array(r)

        encoded_train.append(r)
        original_train.append(xin)
|
||
|
|
||
|
def features_(x, moments=15, encoded=False):
    """
    Summarise each row of values along axis 2 by its mean and root moments.

    ``x`` must be 3-D. Feature 0 is the mean along axis 2. Feature ``k``,
    for ``1 <= k < moments``, is ``|m_k| ** (1/k)``, where ``m_k`` is the
    k-th central moment from ``scipy.stats.moment``. The result has shape
    ``x.shape[:2] + (moments,)``. ``encoded`` is accepted but unused.
    """
    data = np.asarray(x, dtype=np.float32)
    axis = 2
    stats = [np.mean(data, axis)]
    for order in range(1, moments):
        central = scipy.stats.moment(data, moment=order, axis=axis)
        stats.append(abs(central) ** (1 / order))
    # Feature axis goes last.
    return np.stack(stats, axis=-1)
|
||
|
|
||
|
|
||
|
def features(x, encoded, processes=96):
    """
    Compute per-patch moment features for original or encoded images, in parallel.

    Args:
        x: a 3-D array, centred by subtracting its mean over axis 0; or a 4-D
           array, centred over axis 1 and then flattened to 3-D by merging
           its first two axes.
        encoded: unused; kept for call-site compatibility.
        processes: number of worker processes used to run ``features_``.

    Returns:
        ``features_`` applied to the centred 3-D data, i.e. an array of shape
        ``(n, x.shape[-2], 15)``.
    """
    print('start shape',x.shape)
    if len(x.shape) == 3:
        x = x - np.mean(x,axis=0,keepdims=True)
    else:
        # assumed layout: count x 100 x 256 x 768 (TODO confirm)
        print(x[0].shape)
        x = x - np.mean(x,axis=1,keepdims=True)
        # remove per-neural-network dimension
        x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
    # Ceil-divide: at most `processes` chunks, and never a zero chunk size.
    # The old len(x)//96 was 0 for fewer than 96 rows, making range() raise.
    chunk = max(1, -(-len(x) // processes))
    with mp.Pool(processes) as pool:
        print(1)
        chunks = [x[i:i+chunk] for i in range(0, len(x), chunk)]
        print(2)
        # features_ treats rows independently, so chunking does not change results.
        r = pool.map(features_, chunks)
        print(3)
    return np.concatenate(r, axis=0)
|
||
|
|
||
|
|
||
|
|
||
|
def get_train_features():
    """
    Create features for the entire training set.

    Reads the module-level ``original_train`` and ``encoded_train`` (filled by
    ``gen_train_data``). Sets ``ys_train`` to the features of the encoded
    patches and ``xs_train`` to the features of the corresponding original
    image patches.
    """
    global xs_train, ys_train
    print(x_train.shape)
    original_train_ = np.array(original_train)
    encoded_train_ = np.array(encoded_train)

    print("Computing features")
    ys_train = features(encoded_train_, True)

    patch_size = 16
    outer, batch, channels, rows, cols = original_train_.shape
    ss = rows // patch_size
    # Cut every image into non-overlapping patch_size x patch_size patches:
    #   [outer, batch, channels, ss, patch, ss, patch]
    #   -> [outer, batch, ss, ss, channels, patch, patch]
    #   -> [outer, batch, ss**2, channels * patch**2]
    # so that features() summarises each patch along the last axis.
    # (The last size used to be patch_size**2, which only worked for 1 channel.)
    original_train_ = original_train_.reshape(
        (outer, batch, channels, ss, patch_size, ss, patch_size)
    ).transpose((0, 1, 3, 5, 2, 4, 6)).reshape(
        (outer, batch, ss ** 2, channels * patch_size ** 2))

    xs_train = features(original_train_, False)

    print(xs_train.shape, ys_train.shape)
|
||
|
|
||
|
|
||
|
def train_model():
    """
    Train the patch similarity function.

    Reads the module-level ``xs_train`` / ``ys_train`` feature arrays and
    flattens them to (rows, features). The last ``valid_size`` rows are held
    out for validation. ``Model`` is trained with a contrastive loss, and after
    every epoch the EMA weights are saved to the ``saved`` checkpoint. Sets the
    module-level ``model`` and ``ema``. Returns early if the loss becomes NaN.
    """
    global ema, model

    model = Model()
    def loss(x, y):
        """
        K-way contrastive loss as in SimCLR et al.
        The idea is that we should embed x and y so that they are similar
        to each other, and dis-similar from others. To do this we have a
        softmax loss over one dimension to make the values large on the diagonal
        and small off-diagonal.
        """
        a = model.encode(x)
        b = model.decode(y)

        # mat[i, j]: similarity of original patch i and encoded patch j.
        mat = a@b.T
        return objax.functional.loss.cross_entropy_logits_sparse(
            # Learned temperature: exp of the clipped scale weight.
            logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat,
            labels=np.arange(a.shape[0])).mean()

    ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)
    gv = objax.GradValues(loss, model.vars())

    encode_ema = ema.replace_vars(lambda x: model.encode(x))
    decode_ema = ema.replace_vars(lambda y: model.decode(y))

    def train_op(x, y):
        """
        No one was ever fired for using Adam with 1e-4.
        """
        g, v = gv(x, y)
        opt(1e-4, g)
        ema()
        return v

    opt = objax.optimizer.Adam(model.vars())
    train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars())

    ys_ = ys_train

    print(ys_.shape)

    # Flatten to one row per patch; the last axis is the feature axis.
    xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
    ys_ = ys_.reshape((-1, ys_train.shape[-1]))

    # The model scale trick here is taken from CLIP.
    # Let the model decide how confident to make its own predictions.
    model.scale.w.assign(jn.zeros((1,1)))

    valid_size = 1000

    print(xs_train.shape)
    # SimCLR likes big batches
    B = 4096
    for it in range(80):
        print()
        ms = []
        for i in range(1000):
            # The first epoch (it == 0) uses 64x smaller batches to make training more stable
            bs = [B//64, B][it>0]
            # Sample only from the training rows (the last valid_size rows are held out).
            batch = np.random.randint(0, len(xs_)-valid_size, size=bs)
            r = train_op(xs_[batch], ys_[batch])

            # This shouldn't happen, but if it does, better to abort early
            if np.isnan(r):
                print("Die on nan")
                print(ms[-100:])
                return
            ms.append(r)

        print('mean',np.mean(ms), 'scale', model.scale.w.value)
        print('loss',loss(xs_[-100:], ys_[-100:]))

        a = encode_ema(xs_[-valid_size:])
        b = decode_ema(ys_[-valid_size:])

        # Baseline: the same embeddings paired with randomly shuffled partners.
        br = b[np.random.permutation(len(b))]

        print('score',np.mean(np.sum(a*b,axis=(1)) - np.sum(a*br,axis=(1))),
              np.mean(np.sum(a*b,axis=(1)) > np.sum(a*br,axis=(1))))
        # Save the EMA weights (not the raw ones) by swapping them in for the save.
        ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
        ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
|
||
|
|
||
|
|
||
|
|
||
|
def load_challenge():
    """
    Load the challenge dataset for attacking.

    Sets these module-level names:
      ``encoded``   -- contents of ``challenge-7.npy`` (the encoded images);
      ``ooriginal`` -- contents of ``orig-7.npy`` (the original images);
      ``original``  -- ``ooriginal`` cut into non-overlapping single-channel
                       16x16 patches, shaped (images, patches, 256).
    """
    global xs, ys, encoded, original, ooriginal
    print("SETUP: Loading matrixes")
    encoded = np.load("challenge-7.npy")
    ooriginal = np.load("orig-7.npy")

    print("Sizes", encoded.shape, ooriginal.shape)

    # (N, 1, side*k, side*k) -> (N, k, k, 1, side, side) -> (N, k*k, side*side),
    # so each patch's values sit on the last axis.
    side = 16
    count = ooriginal.shape[0]
    per_axis = ooriginal.shape[2] // side
    tiles = ooriginal.reshape((count, 1, per_axis, side, per_axis, side))
    tiles = tiles.transpose((0, 2, 4, 1, 3, 5))
    original = tiles.reshape((count, per_axis ** 2, side ** 2))
|
||
|
|
||
|
|
||
|
def match_sub(args):
    """
    Find the best way to undo the permutation between two images.

    ``args`` is a pair ``(reference, candidate)`` of 2-D arrays (rows are
    patches). Pairing candidate row ``i`` with reference row ``j`` costs their
    squared Euclidean distance. The return value gives, for each candidate
    row, the index of the reference row it is assigned to under the
    minimum-cost assignment.
    """
    reference, candidate = args
    cost = np.sum((reference[None, :, :] - candidate[:, None, :]) ** 2, axis=2)
    _, assignment = scipy.optimize.linear_sum_assignment(cost)
    return assignment
|
||
|
|
||
|
|
||
|
def recover_local_permutation():
    """
    Given a set of encoded images, return a new encoding without permutations.

    Each image in the module-level ``encoded`` has its patches matched
    (``match_sub``) against those of ``encoded[0]``. The patches are then
    reordered so that position ``j`` holds the patch matched to reference
    patch ``j``. Replaces ``encoded`` with the reordered array.
    """
    global encoded

    print('recover local')
    # `with` guarantees the workers are torn down even if a match fails.
    with mp.Pool(96) as p:
        local_perm = np.array(p.map(match_sub, [(encoded[0], e) for e in encoded]))

    # local_perm[i, k] is the reference slot of patch k; argsort inverts it.
    order = np.argsort(local_perm, axis=1)
    encoded = np.take_along_axis(np.asarray(encoded), order[:, :, None], axis=1)
|
||
|
|
||
|
|
||
|
def recover_better_local_permutation():
    """
    Given a set of encoded images, return a new encoding, but better!

    Like ``recover_local_permutation``, but every image in the module-level
    ``encoded`` is matched against the mean image ``encoded.mean(0)`` instead
    of against image 0. Replaces ``encoded`` with the reordered array.
    """
    global encoded

    # Pairing onto the mean vector instead of image 0 is slightly more noise resistant.
    target = encoded.mean(0)
    with mp.Pool(96) as p:
        local_perm = np.array(p.map(match_sub, [(target, e) for e in encoded]))

    # Probably we didn't change by much, generally <0.1%
    print('improved changed by', np.mean(local_perm != np.arange(local_perm.shape[1])))

    # local_perm[i, k] is the target slot of patch k; argsort inverts it.
    order = np.argsort(local_perm, axis=1)
    encoded = np.take_along_axis(np.asarray(encoded), order[:, :, None], axis=1)
|
||
|
|
||
|
|
||
|
def compute_patch_similarity():
    """
    Compute the embedding of every patch with the trained network.

    Computes moment features of ``encoded`` (into the module-level ``ys``)
    and ``original`` (into ``xs``). Restores ``Model`` weights from the
    ``saved`` checkpoint and stores the embeddings in ``xs_image`` (originals,
    via ``encode``) and ``ys_image`` (encoded images, via ``decode``).
    """
    global xs, ys, xs_image, ys_image

    print("Computing features")
    ys = features(encoded, encoded=True)
    xs = features(original, encoded=False)

    net = Model()
    objax.io.Checkpoint("saved", keep_ckpts=0).restore(net.vars())

    xs_image, ys_image = net.encode(xs), net.decode(ys)
    assert xs.shape[0] == xs_image.shape[0]
    print("Done")
|
||
|
|
||
|
|
||
|
def match(args, ret_col=False):
    """
    Compute the similarity between image features and encoded features.

    ``args`` is ``(vec1, vec2s)``. ``vec1`` is a 2-D array of patch embeddings
    for one image, and ``vec2s`` is an iterable of such arrays. For each
    ``vec2``, patches are paired by the assignment that maximises the total
    dot product. The score is the mean dot product of the chosen pairs.

    ``ret_col`` is accepted for compatibility but is not used.

    Returns:
        A list with one score per entry of ``vec2s``.
    """
    # A leftover debug marker used to create a new /tmp/start*.* file (through a
    # never-closed handle) on every call; it has been removed.
    vec1, vec2s = args
    scores = []
    for vec2 in vec2s:
        # value[i, j] = vec2[i] . vec1[j]
        value = np.sum(vec1[None,:,:] * vec2[:,None,:],axis=2)
        row, col = scipy.optimize.linear_sum_assignment(-value)
        scores.append(value[row,col].mean())
    return scores
|
||
|
|
||
|
|
||
|
|
||
|
def recover_global_matching_first():
    """
    Recover the global matching of original to encoded images by doing
    an all-pairs matching problem.

    Uses ``match`` to build a score matrix (rows: originals in ``xs_image``,
    columns: encoded images in ``ys_image``). Solves the maximum-weight
    assignment and stores in the module-level ``global_matching`` the
    original index assigned to each encoded image.
    """
    global global_matching

    xs_image_ = np.array(xs_image)
    ys_image_ = np.array(ys_image)

    # `with` guarantees the workers are torn down even if a match fails.
    with mp.Pool(96) as p:
        matrix = p.map(match, [(x, ys_image_) for x in xs_image_])
    matrix = np.array(matrix).reshape((len(xs_image_), -1))

    row, col = scipy.optimize.linear_sum_assignment(-matrix)
    # col[orig] = encoded index; argsort turns it into encoded -> original.
    global_matching = np.argsort(col)
    print('glob',list(global_matching))
|
||
|
|
||
|
|
||
|
|
||
|
def recover_global_permutation():
    """
    Find the way that the encoded images are permuted off of the original images.

    For every encoded image ``i`` and its matched original
    ``global_matching[i]``, compute the patch-by-patch dot products and average
    them over images. The maximum-weight assignment on that average gives the
    shared patch order. The module-level ``global_permutation`` is set so that
    ``encoded[:, global_permutation]`` lists patches in the originals' order.
    """
    global global_permutation

    print("Glob match", global_matching)
    overall = []
    for i,j in enumerate(global_matching):
        # overall[-1][a, b] = encoded patch a . original patch b
        overall.append(np.sum(xs_image[j][None,:,:] * ys_image[i][:,None,:],axis=2))
    overall = np.mean(overall, 0)

    row, col = scipy.optimize.linear_sum_assignment(-overall)
    new_permutation = np.argsort(col)

    try:
        previous = global_permutation
    except NameError:
        # First call: there is no earlier estimate to compare against.
        pass
    else:
        print("Changed frac:", np.mean(previous != new_permutation))

    global_permutation = new_permutation
|
||
|
|
||
|
|
||
|
def recover_global_matching_second():
    """
    Re-match each encoded image with its original image, now relying on the
    global patch permutation.

    Each encoded image in ``ys_image`` is reordered by ``global_permutation``
    so that its patches line up with the originals'. The score of encoded
    image ``i`` against original ``j`` is the mean elementwise product of the
    aligned embeddings. A maximum-weight assignment on these scores replaces
    the module-level ``global_matching`` (original index per encoded image).
    """
    global global_matching_second, global_matching

    aligned = np.array([ys_image[k][global_permutation] for k in range(ys_image.shape[0])])

    print(xs_image.shape)

    # Score 10 encoded images at a time to bound the broadcast's memory.
    step = 10
    score_rows = []
    for start in range(0, len(xs_image), step):
        chunk = aligned[start:start + step]
        score_rows.extend(np.mean(xs_image[None, :, :, :] * chunk[:, None, :, :], axis=(2, 3)))
    sims = np.array(score_rows)
    print(sims.shape)

    _, col = scipy.optimize.linear_sum_assignment(-sims)

    print('arg',sims.argmax(1))

    print("Same matching frac", np.mean(col == global_matching) )
    print(col)
    global_matching = col
|
||
|
|
||
|
|
||
|
def extract_by_training(resume):
    """
    Final recovery process by extracting the neural network.

    Trains a PrivateEncoder (``inverse``, module-level) to map the original
    images to their encodings. Inputs are ``ooriginal[global_matching]`` and
    targets are ``encoded[:, global_permutation, :]``. Training runs 2000 Adam
    steps on random batches of 32 with an L1 loss, on ``cuda:1``.

    Args:
        resume: if true, keep training the existing ``inverse``; otherwise
            start from a freshly initialised network. A new Adam optimizer
            is created on every call either way.
    """
    global inverse

    device = torch.device('cuda:1')

    if not resume:
        inverse = PrivateEncoder(Xray(True)).cuda(device)

    # More adam to train.
    optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001)

    this_xs = ooriginal[global_matching]
    this_ys = encoded[:,global_permutation,:]

    for i in range(2000):
        # Indices in [0, len-1]. randint's upper bound is exclusive; this
        # replaces the deprecated np.random.random_integers(0, len-1, 32).
        idx = np.random.randint(0, len(this_xs), 32)
        xbatch = torch.tensor(this_xs[idx]).cuda(device)
        ybatch = torch.tensor(this_ys[idx]).cuda(device)

        optimizer.zero_grad()

        guess_output = inverse(xbatch)
        # L1 loss because we don't want to be sensitive to outliers
        error = torch.mean(torch.abs(guess_output-ybatch))
        error.backward()

        optimizer.step()

        print(error)
|
||
|
|
||
|
|
||
|
|
||
|
def test_extract():
    """
    Now we can recover the matching much better by computing the estimated
    encodings for each original image.

    Every original in ``ooriginal`` is run through ``inverse`` in batches of
    32. The squared L2 distance between each prediction and each real
    encoding (patches put in original order by ``global_permutation``) fills
    ``smatrix``. A minimum-cost assignment on it replaces ``global_matching``
    (original index per encoded image). Also sets the module-level
    ``guessed_encoded`` and ``smatrix``.
    """
    global global_matching, guessed_encoded, smatrix

    device = torch.device('cuda:1')

    print(ooriginal.shape, encoded.shape)

    out = []
    # Inference only: no_grad avoids building and holding autograd graphs.
    with torch.no_grad():
        for i in range(0,len(ooriginal),32):
            print(i)
            out.extend(inverse(torch.tensor(ooriginal[i:i+32]).cuda(device)).cpu().numpy())

    guessed_encoded = np.array(out)

    # Now we have to compare each encoded image with every other original image.
    # Do this fast with some matrix multiplies.

    out = guessed_encoded.reshape((len(encoded), -1))
    real = encoded[:,global_permutation,:].reshape((len(encoded), -1))

    @jax.jit
    def foo(x, y):
        # Pairwise squared L2 distances between rows of x and rows of y.
        return jn.square(x[:,None] - y[None,:]).sum(2)

    smatrix = np.zeros((len(out), len(out)))

    # Fill the distance matrix in 500x500 tiles to bound device memory.
    B = 500
    for i in range(0,len(out),B):
        print(i)
        for j in range(0,len(out),B):
            smatrix[i:i+B, j:j+B] = foo(out[i:i+B], real[j:j+B])

    # And the final time you'll have to look at a min weight matching, I promise.
    row, col = scipy.optimize.linear_sum_assignment(smatrix)

    print(list(row)[::100])

    # col[orig] = encoded index; argsort turns it into encoded -> original.
    print("Differences", np.mean(np.argsort(col) != global_matching))

    global_matching = np.argsort(col)
|
||
|
|
||
|
|
||
|
def perf(steps=[]):
    """Print wall-clock time since the previous call and since the first call.

    The default ``steps`` list is deliberately shared across calls; it is the
    timestamp log. The first call only records a start time. Each later call
    prints both elapsed times, records a new timestamp and sleeps one second.
    """
    if steps:
        print("Last Time Elapsed:", time.time()-steps[-1], ' Total Time Elapsed:', time.time()-steps[0])
        steps.append(time.time())
        time.sleep(1)
    else:
        steps.append(time.time())
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Stage switches (these replace bare `if True:` toggles). The training
    # stage saves the similarity model to the "saved" checkpoint, which
    # compute_patch_similarity restores, so the attack can be re-run alone.
    RUN_TRAINING = True
    RUN_ATTACK = True

    if RUN_TRAINING:
        perf()
        setup('challenge')
        perf()
        gen_train_data()
        perf()
        get_train_features()
        perf()
        train_model()
        perf()

    if RUN_ATTACK:
        load_challenge()
        perf()
        recover_local_permutation()
        perf()
        recover_better_local_permutation()
        perf()
        compute_patch_similarity()
        perf()
        recover_global_matching_first()
        perf()

        # Alternate between refining the patch order and the image matching.
        for _ in range(3):
            recover_global_permutation()
            perf()
            recover_global_matching_second()
            perf()

        # Fit the surrogate encoder, then use it to refine the matching.
        for i in range(3):
            recover_global_permutation()
            perf()
            extract_by_training(i > 0)
            perf()
            test_extract()
            perf()

        # perf() returns None; printing its result only printed a stray "None".
        perf()
|