Merge pull request #185 from carlini:neuracrypt
PiperOrigin-RevId: 429632517
This commit is contained in:
commit
7e0b193393
1 changed files with 712 additions and 0 deletions
712
research/neuracrypt_attack_2021/attack.py
Normal file
712
research/neuracrypt_attack_2021/attack.py
Normal file
|
@ -0,0 +1,712 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# This program solves the NeuraCrypt challenge to 100% accuracy.
|
||||
# Given a set of encoded images and original versions of those,
|
||||
# it shows how to match the original to the encoded.
|
||||
|
||||
import collections
|
||||
import hashlib
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import scipy.stats
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
|
||||
import jax
|
||||
import jax.numpy as jn
|
||||
import objax
|
||||
import scipy.optimize
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
|
||||
# Objax neural network that's going to embed patches to a
# low dimensional space to guess if two patches correspond
# to the same original image.
|
||||
class Model(objax.Module):
    """Two-tower embedding network used as the patch similarity function.

    ``encoder`` embeds feature vectors of original-image patches and
    ``decoder`` embeds feature vectors of encoded patches.  Both towers take
    15-dimensional inputs and produce 8-dimensional embeddings that are
    rescaled to unit L2 norm, so a dot product between the two embeddings is
    a cosine similarity.  ``scale`` holds a single learnable weight.
    """

    def __init__(self):
        in_features = 15
        hidden = 64
        embed_dim = 8

        def make_tower():
            # Three linear layers with leaky-ReLU in between.
            return objax.nn.Sequential([
                objax.nn.Linear(in_features, hidden),
                objax.functional.leaky_relu,
                objax.nn.Linear(hidden, hidden),
                objax.functional.leaky_relu,
                objax.nn.Linear(hidden, embed_dim),
            ])

        # The encoder is created before the decoder so that the order in which
        # parameters are initialised stays the same.
        self.encoder = make_tower()
        self.decoder = make_tower()
        self.scale = objax.nn.Linear(1, 1, use_bias=False)

    @staticmethod
    def _unit_norm(v):
        """Divide ``v`` by its L2 norm taken over the last axis."""
        return v / jn.sum(v ** 2, axis=-1, keepdims=True) ** .5

    def encode(self, x):
        """Embed features of original images into the shared feature space."""
        return self._unit_norm(self.encoder(x))

    def decode(self, x):
        """Embed features of encoded images into the shared feature space."""
        return self._unit_norm(self.decoder(x))
|
||||
|
||||
# Proxy dataset for analysis
|
||||
class ImageNet:
    """Encoder settings for the ImageNet proxy dataset.

    The class attributes are the fields ``PrivateEncoder`` reads from its
    ``args`` object.
    """

    img_size = (256, 256)        # input image size
    num_chan = 3                 # input channels
    private_kernel_size = 16     # patch size (kernel and stride of first conv)
    private_depth = 7            # number of 1x1 conv blocks after the first conv
    hidden_dim = 2048            # width of the encoder

    def __init__(self, remove):
        # When true, the encoder does not shuffle the patches of each image.
        self.remove_pixel_shuffle = remove
|
||||
|
||||
# Original dataset as used in the NeuraCrypt paper
|
||||
class Xray:
    """Encoder settings for the X-ray dataset used in the NeuraCrypt paper.

    The class attributes are the fields ``PrivateEncoder`` reads from its
    ``args`` object.
    """

    img_size = (256, 256)        # input image size
    num_chan = 1                 # input channels
    private_kernel_size = 16     # patch size (kernel and stride of first conv)
    private_depth = 4            # number of 1x1 conv blocks after the first conv
    hidden_dim = 2048            # width of the encoder

    def __init__(self, remove):
        # When true, the encoder does not shuffle the patches of each image.
        self.remove_pixel_shuffle = remove
|
||||
|
||||
## The following class is taken directly from the NeuraCrypt codebase.
|
||||
## https://github.com/yala/NeuraCrypt
|
||||
## which is originally licensed under the MIT License
|
||||
class PrivateEncoder(nn.Module):
    """NeuraCrypt private encoder (adapted from the NeuraCrypt codebase).

    A strided convolution splits the image into non-overlapping patches, a
    stack of 1x1 convolutions with batch normalisation processes each patch
    independently, a learned positional embedding is added, and a final
    ReLU + linear "mixer" produces one ``hidden_dim`` vector per patch.
    Unless ``args.remove_pixel_shuffle`` is set, the patches of every image
    are then randomly permuted.

    Args:
        args: settings object providing ``num_chan``, ``private_kernel_size``,
            ``hidden_dim``, ``img_size``, ``private_depth`` and
            ``remove_pixel_shuffle``.  ``args.input_dim`` is overwritten with
            ``args.hidden_dim`` as a side effect.
        width_factor: multiplier applied to ``hidden_dim`` for the internal
            layers.
    """

    def __init__(self, args, width_factor=1):
        super(PrivateEncoder, self).__init__()
        self.args = args
        input_dim = args.num_chan
        patch_size = args.private_kernel_size
        output_dim = args.hidden_dim
        num_patches = (args.img_size[0] // patch_size) ** 2
        self.noise_size = 1

        args.input_dim = args.hidden_dim

        width = output_dim * width_factor

        # Patchify with a strided conv, then `private_depth` per-patch blocks.
        layers = [
            nn.Conv2d(input_dim, width, kernel_size=patch_size, dilation=1,
                      stride=patch_size),
            nn.ReLU(),
        ]
        for _ in range(self.args.private_depth):
            layers.extend([
                nn.Conv2d(width, width, kernel_size=1, dilation=1, stride=1),
                nn.BatchNorm2d(width, track_running_stats=False),
                nn.ReLU(),
            ])

        self.image_encoder = nn.Sequential(*layers)

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width))

        self.mixer = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width, output_dim),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch accepted by the first ``Conv2d`` layer.

        Returns:
            Tensor of shape ``(B, H*W, hidden_dim)`` where ``H`` and ``W`` are
            the spatial dimensions of the convolutional output.
        """
        encoded = self.image_encoder(x)
        B, C, H, W = encoded.size()
        encoded = encoded.view([B, -1, H * W]).transpose(1, 2)
        # Add out of place: the in-place form mutated a view of the last
        # ReLU's output, which autograd may need intact for the backward pass.
        encoded = encoded + self.pos_embedding
        encoded = self.mixer(encoded)

        ## Shuffle indices
        if not self.args.remove_pixel_shuffle:
            shuffled = torch.zeros_like(encoded)
            for i in range(B):
                idx = torch.randperm(H * W, device=encoded.device)
                # Row j of the output is row idx[j] of the input.
                shuffled[i] = encoded[i, idx]
            encoded = shuffled

        return encoded
|
||||
## End copied code
|
||||
|
||||
def setup(ds):
|
||||
"""
|
||||
Load the datasets to use. Nothing interesting to see.
|
||||
"""
|
||||
global x_train, y_train
|
||||
if ds == 'imagenet':
|
||||
import torchvision
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize(256),
|
||||
torchvision.transforms.CenterCrop(256),
|
||||
torchvision.transforms.ToTensor()])
|
||||
imagenet_data = torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/',
|
||||
split='val',
|
||||
transform=transform)
|
||||
data_loader = torch.utils.data.DataLoader(imagenet_data,
|
||||
batch_size=100,
|
||||
shuffle=True,
|
||||
num_workers=8)
|
||||
r = []
|
||||
for x,_ in data_loader:
|
||||
if len(r) > 1000: break
|
||||
print(x.shape)
|
||||
r.extend(x.numpy())
|
||||
x_train = np.array(r)
|
||||
print(x_train.shape)
|
||||
elif ds == 'xray':
|
||||
import torchvision
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize(256),
|
||||
torchvision.transforms.CenterCrop(256),
|
||||
torchvision.transforms.ToTensor()])
|
||||
imagenet_data = torchvision.datasets.ImageFolder('CheXpert-v1.0/train',
|
||||
transform=transform)
|
||||
data_loader = torch.utils.data.DataLoader(imagenet_data,
|
||||
batch_size=100,
|
||||
shuffle=True,
|
||||
num_workers=8)
|
||||
r = []
|
||||
for x,_ in data_loader:
|
||||
if len(r) > 1000: break
|
||||
print(x.shape)
|
||||
r.extend(x.numpy())
|
||||
x_train = np.array(r)
|
||||
print(x_train.shape)
|
||||
elif ds == 'challenge':
|
||||
x_train = np.load("orig-7.npy")
|
||||
print(np.min(x_train), np.max(x_train), x_train.shape)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def gen_train_data():
    """
    Generate aligned training data to train a patch similarity function.
    Given some original images, generate lots of encoded versions.

    For each of 30 rounds a freshly initialised ``PrivateEncoder`` (with
    patch shuffling disabled) encodes 100 images sampled with replacement
    from ``x_train``.

    Side effects:
        Sets the module-level lists ``encoded_train`` and ``original_train``;
        entry ``k`` of each holds the encodings / originals of round ``k``.
    """
    global encoded_train, original_train

    encoded_train = []
    original_train = []

    args = Xray(True)

    C = 100
    # Derive one distinct seed per round.  Re-reading the clock every round
    # could hand two rounds the same seed (and so the same encoder) if they
    # started within the same second.
    base_seed = int(time.time())
    for round_idx in range(30):
        print(round_idx)
        torch.manual_seed(base_seed + round_idx)
        e = PrivateEncoder(args).cuda()
        batch = np.random.randint(0, len(x_train), size=C)
        xin = x_train[batch]

        r = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for start in range(0, C, 32):
                chunk = torch.tensor(xin[start:start + 32]).cuda()
                r.extend(e(chunk).detach().cpu().numpy())
        r = np.array(r)

        encoded_train.append(r)
        original_train.append(xin)
|
||||
|
||||
def features_(x, moments=15, encoded=False):
    """
    Compute higher-order moments for patches in an image to use as
    features for the neural network.

    Args:
        x: array-like whose axis 2 holds the values of one patch.
        moments: number of features per patch: the mean followed by the
            ``i``-th root of the absolute ``i``-th central moment for
            ``i = 1 .. moments - 1``.
        encoded: unused.

    Returns:
        Array with the first two axes of ``x`` followed by a feature axis of
        length ``moments``.
    """
    patches = np.array(x, dtype=np.float32)
    axis = 2

    feats = [np.mean(patches, axis)]
    for order in range(1, moments):
        central = scipy.stats.moment(patches, moment=order, axis=axis)
        feats.append(abs(central) ** (1 / order))

    # Stacked as (feature, a, b); move the feature axis to the end.
    stacked = np.array(feats)
    return stacked.transpose((1, 2, 0))
|
||||
|
||||
|
||||
def features(x, encoded):
    """
    Given the original images or the encoded images, generate the
    features to use for the patch similarity function.

    Args:
        x: 3-D array, or 4-D array whose leading axis indexes separate
            groups (one per encoder network).  A 3-D input is centred by
            subtracting its mean over axis 0; a 4-D input is centred per
            group (mean over axis 1) and its first two axes are then merged.
        encoded: unused.

    Returns:
        Concatenation along axis 0 of ``features_`` applied to consecutive
        slices of the centred array.
    """
    print('start shape', x.shape)
    if len(x.shape) == 3:
        x = x - np.mean(x, axis=0, keepdims=True)
    else:
        # count x 100 x 256 x 768
        print(x[0].shape)
        x = x - np.mean(x, axis=1, keepdims=True)
        # remove per-neural-network dimension
        x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])

    num_workers = 96
    # At least one row per chunk: with fewer rows than workers the integer
    # division is 0 and range() would reject a zero step.
    B = max(1, len(x) // num_workers)
    print(1)
    bs = [x[i:i + B] for i in range(0, len(x), B)]
    print(2)
    # The context manager releases the workers even if map() raises.
    with mp.Pool(num_workers) as p:
        r = p.map(features_, bs)
    print(3)
    return np.concatenate(r, axis=0)
|
||||
|
||||
|
||||
|
||||
def get_train_features():
    """
    Create features for the entire datasets.

    Reads the module-level ``original_train`` / ``encoded_train`` lists and
    sets the module-level ``xs_train`` (features of original patches) and
    ``ys_train`` (features of encoded patches).
    """
    global xs_train, ys_train
    print(x_train.shape)
    original_train_ = np.array(original_train)
    encoded_train_ = np.array(encoded_train)

    print("Computing features")
    ys_train = features(encoded_train_, True)

    patch_size = 16
    outer, batch, channels, height, width = original_train_.shape
    rows = height // patch_size
    cols = width // patch_size
    # Okay so this is an ugly transpose block.
    # We are going from [outer_batch, batch_size, channels, height, width]
    # to [outer_batch, batch_size, channels, rows, patch_size, cols, patch_size].
    # Then we transpose and flatten so that we end up with
    # [outer_batch, batch_size, rows*cols, channels*patch_size**2]
    # so that now we can run features on the last dimension.
    # The channel count is folded into the last axis explicitly; the previous
    # hard-coded patch_size**2 only worked for single-channel square images.
    original_train_ = original_train_.reshape(
        (outer, batch, channels, rows, patch_size, cols, patch_size))
    original_train_ = original_train_.transpose((0, 1, 3, 5, 2, 4, 6))
    original_train_ = original_train_.reshape(
        (outer, batch, rows * cols, channels * patch_size ** 2))

    xs_train = features(original_train_, False)

    print(xs_train.shape, ys_train.shape)
|
||||
|
||||
|
||||
def train_model():
    """
    Train the patch similarity function.

    Reads the module-level ``xs_train`` / ``ys_train`` feature arrays, sets
    the module-level ``model`` and ``ema``, and writes the exponential-moving-
    average weights to the "saved" checkpoint directory when training ends.
    Returns early, without saving, if the loss becomes NaN.
    """
    global ema, model

    model = Model()

    def contrastive_loss(x, y):
        """
        K-way contrastive loss as in SimCLR et al.
        Embeddings of matching x/y rows should be similar to each other and
        dissimilar from all other rows: a softmax cross-entropy over the
        similarity matrix pushes the diagonal up and everything else down.
        """
        emb_x = model.encode(x)
        emb_y = model.decode(y)

        similarity = emb_x @ emb_y.T
        logit_scale = jn.exp(jn.clip(model.scale.w.value, -2, 4))
        return objax.functional.loss.cross_entropy_logits_sparse(
            logits=logit_scale * similarity,
            labels=np.arange(emb_x.shape[0])).mean()

    ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)
    gv = objax.GradValues(contrastive_loss, model.vars())
    opt = objax.optimizer.Adam(model.vars())

    # Evaluate the towers with the averaged weights swapped in.
    encode_ema = ema.replace_vars(lambda x: model.encode(x))
    decode_ema = ema.replace_vars(lambda y: model.decode(y))

    def step(x, y):
        """One Adam update at learning rate 1e-4, followed by an EMA update."""
        grads, value = gv(x, y)
        opt(1e-4, grads)
        ema()
        return value

    step = objax.Jit(step, gv.vars() + opt.vars() + ema.vars())

    print(ys_train.shape)

    # Flatten everything down to one row per patch.
    xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
    ys_ = ys_train.reshape((-1, ys_train.shape[-1]))

    # The model scale trick here is taken from CLIP.
    # Let the model decide how confident to make its own predictions.
    model.scale.w.assign(jn.zeros((1, 1)))

    # The last `valid_size` rows are never sampled for training.
    valid_size = 1000

    print(xs_train.shape)
    # SimCLR likes big batches
    B = 4096
    for it in range(80):
        print()
        ms = []
        # First epoch uses a smaller batch, to make training more stable
        bs = B if it > 0 else B // 64
        for _ in range(1000):
            batch = np.random.randint(0, len(xs_) - valid_size, size=bs)
            r = step(xs_[batch], ys_[batch])

            # This shouldn't happen, but if it does, better to abort early
            if np.isnan(r):
                print("Die on nan")
                print(ms[-100:])
                return
            ms.append(r)

        print('mean', np.mean(ms), 'scale', model.scale.w.value)
        print('loss', contrastive_loss(xs_[-100:], ys_[-100:]))

        a = encode_ema(xs_[-valid_size:])
        b = decode_ema(ys_[-valid_size:])

        # Randomly re-paired embeddings serve as the mismatched baseline.
        br = b[np.random.permutation(len(b))]

        matched = np.sum(a * b, axis=(1))
        mismatched = np.sum(a * br, axis=(1))
        print('score', np.mean(matched - mismatched),
              np.mean(matched > mismatched))

    ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
    ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
|
||||
|
||||
|
||||
|
||||
def load_challenge():
    """
    Load the challenge dataset for attacking.

    Side effects:
        Sets the module-level ``encoded`` (from "challenge-7.npy"),
        ``ooriginal`` (from "orig-7.npy", unchanged) and ``original``
        (the same images cut into flattened 16x16 patches).
    """
    global xs, ys, encoded, original, ooriginal
    print("SETUP: Loading matrixes")
    # The encoded images
    encoded = np.load("challenge-7.npy")
    # And the original images
    ooriginal = original = np.load("orig-7.npy")

    print("Sizes", encoded.shape, ooriginal.shape)

    # Again do that ugly resize thing to make the features be on the last
    # dimension: split both spatial axes into (grid, patch_size), bring the
    # two grid axes forward, then flatten each patch.
    patch_size = 16
    count = original.shape[0]
    ss = original.shape[2] // patch_size
    grid = ooriginal.reshape((count, 1, ss, patch_size, ss, patch_size))
    grid = grid.transpose((0, 2, 4, 1, 3, 5))
    original = grid.reshape((count, ss ** 2, patch_size ** 2))
|
||||
|
||||
|
||||
def match_sub(args):
    """
    Find the best way to undo the permutation between two images.

    Args:
        args: ``(vec1, vec2)`` pair of 2-D arrays, packed into one tuple so
            the function can be used with ``Pool.map``.

    Returns:
        Column indices of the minimum-cost assignment on the matrix of
        squared L2 distances, whose rows index ``vec2`` and whose columns
        index ``vec1``.
    """
    reference, candidate = args
    # cost[j, i] = || reference[i] - candidate[j] ||^2
    delta = reference[None, :, :] - candidate[:, None, :]
    cost = np.sum(delta ** 2, axis=2)
    _, assigned = scipy.optimize.linear_sum_assignment(cost)
    return assigned
|
||||
|
||||
|
||||
def recover_local_permutation():
    """
    Given a set of encoded images, return a new encoding without permutations.

    Every encoded image is aligned patch-by-patch to the first encoded image
    using ``match_sub``, and the module-level ``encoded`` array is replaced
    by the re-ordered version.
    """
    global encoded, ys

    print('recover local')
    # The context manager releases the workers even if map() raises.
    with mp.Pool(96) as p:
        local_perm = p.map(match_sub, [(encoded[0], e) for e in encoded])
    local_perm = np.array(local_perm)

    # local_perm[i][j] is the reference patch matched to patch j of image i;
    # its argsort therefore lists image i's patches in reference order.
    encoded = np.array([image[np.argsort(perm)]
                        for image, perm in zip(encoded, local_perm)])
|
||||
|
||||
|
||||
def recover_better_local_permutation():
    """
    Given a set of encoded images, return a new encoding, but better!

    Replaces the module-level ``encoded`` array with a version in which each
    image's patches are re-aligned to the mean encoded image.
    """
    global encoded, ys

    # Now instead of pairing all images to image 0, we compute the mean l2 vector
    # and then pair all images onto the mean vector. Slightly more noise resistant.
    target = encoded.mean(0)
    # The context manager releases the workers even if map() raises.
    with mp.Pool(96) as p:
        local_perm = p.map(match_sub, [(target, e) for e in encoded])
    local_perm = np.array(local_perm)

    # Probably we didn't change by much, generally <0.1%
    print('improved changed by', np.mean(local_perm != np.arange(local_perm.shape[1])))

    # local_perm[i][j] is the target patch matched to patch j of image i;
    # its argsort therefore lists image i's patches in target order.
    encoded = np.array([image[np.argsort(perm)]
                        for image, perm in zip(encoded, local_perm)])
|
||||
|
||||
|
||||
def compute_patch_similarity():
    """
    Compute the feature vectors for each patch using the trained neural network.

    Side effects:
        Sets the module-level ``ys`` / ``xs`` (moment features of the encoded
        and original patches) and ``ys_image`` / ``xs_image`` (their
        embeddings under the model restored from the "saved" checkpoint).
    """
    global xs, ys, xs_image, ys_image

    print("Computing features")
    ys = features(encoded, encoded=True)
    xs = features(original, encoded=False)

    # Fresh network whose weights are loaded from the checkpoint directory.
    net = Model()
    objax.io.Checkpoint("saved", keep_ckpts=0).restore(net.vars())

    xs_image = net.encode(xs)
    ys_image = net.decode(ys)
    assert xs.shape[0] == xs_image.shape[0]
    print("Done")
|
||||
|
||||
|
||||
def match(args, ret_col=False):
    """
    Compute the similarity between image features and encoded features.

    Args:
        args: ``(vec1, vec2s)`` where ``vec1`` is a 2-D array of patch
            embeddings for one image and ``vec2s`` is a sequence of such
            arrays.  Packed into one tuple for use with ``Pool.map``.
        ret_col: unused.

    Returns:
        List with one score per entry of ``vec2s``: the mean dot product over
        the patch assignment that maximises the total dot product.

    Side effects:
        Writes a small marker file under /tmp each time it is called.
    """
    vec1, vec2s = args
    r = []
    # Progress marker; the context manager makes sure the handle is closed.
    with open("/tmp/start%d.%d"%(np.random.randint(10000),time.time()),"w") as marker:
        marker.write("hi")
    for vec2 in vec2s:
        # value[j, i] = <vec1[i], vec2[j]>
        value = np.sum(vec1[None, :, :] * vec2[:, None, :], axis=2)

        # Negate so the minimum-cost solver maximises similarity.
        row, col = scipy.optimize.linear_sum_assignment(-value)
        r.append(value[row, col].mean())
    return r
|
||||
|
||||
|
||||
|
||||
def recover_global_matching_first():
    """
    Recover the global matching of original to encoded images by doing
    an all-pairs matching problem.

    Reads the module-level ``xs_image`` / ``ys_image`` embeddings and sets
    the module-level ``global_matching`` so that ``global_matching[k]`` is
    the index of the original image assigned to encoded image ``k``.
    """
    global global_matching, ys_image, encoded

    xs_image_ = np.array(xs_image)
    ys_image_ = np.array(ys_image)

    # One task per original image; each scores it against every encoded image.
    # The context manager releases the workers even if map() raises.
    with mp.Pool(96) as p:
        matrix = p.map(match, [(x, ys_image_) for x in xs_image_])
    matrix = np.array(matrix).reshape((xs_image.shape[0],
                                       xs_image.shape[0]))

    # Rows index original images, columns index encoded images.
    row, col = scipy.optimize.linear_sum_assignment(-np.array(matrix))
    global_matching = np.argsort(col)
    print('glob', list(global_matching))
|
||||
|
||||
|
||||
|
||||
def recover_global_permutation():
    """
    Find the way that the encoded images are permuted off of the original images.

    Averages, over all currently matched (encoded, original) image pairs, the
    patch-by-patch similarity matrix, then solves one assignment problem on
    the average.

    Side effects:
        Sets the module-level ``global_permutation``.  When a previous value
        exists, prints the fraction of entries that changed.
    """
    global global_permutation

    print("Glob match", global_matching)
    overall = []
    for i, j in enumerate(global_matching):
        # Rows index patches of encoded image i, columns patches of original j.
        overall.append(np.sum(xs_image[j][None, :, :] * ys_image[i][:, None, :], axis=2))

    overall = np.mean(overall, 0)

    row, col = scipy.optimize.linear_sum_assignment(-overall)
    updated = np.argsort(col)

    # On the first call there is nothing to compare against.  Check for that
    # case explicitly instead of swallowing every exception.
    previous = globals().get("global_permutation")
    if previous is not None:
        print("Changed frac:", np.mean(previous != updated))

    global_permutation = updated
|
||||
|
||||
|
||||
def recover_global_matching_second():
    """
    Match each encoded image with its original image, but better, by
    relying on the global permutation.

    Re-orders the patches of every encoded embedding with
    ``global_permutation``, scores every (encoded, original) image pair by
    the mean elementwise product of their embeddings, and solves one
    assignment problem.  Replaces the module-level ``global_matching``.
    """
    global global_matching_second, global_matching

    ys_fix = np.array([ys_image[k][global_permutation]
                       for k in range(ys_image.shape[0])])

    print(xs_image.shape)

    # Score ten encoded images at a time against all originals.
    chunk = 10
    sims = []
    for start in range(0, len(xs_image), chunk):
        block = ys_fix[start:start + chunk]
        scores = np.mean(xs_image[None, :, :, :] * block[:, None, :, :], axis=(2, 3))
        sims.extend(scores)
    sims = np.array(sims)
    print(sims.shape)

    # Rows index encoded images, columns index original images.
    row, col = scipy.optimize.linear_sum_assignment(-sims)

    print('arg', sims.argmax(1))

    print("Same matching frac", np.mean(col == global_matching))
    print(col)
    global_matching = col
|
||||
|
||||
|
||||
def extract_by_training(resume):
    """
    Final recovery process by extracting the neural network.

    Trains a ``PrivateEncoder`` (the module-level ``inverse``) to reproduce
    the challenge encodings from the currently matched original images.

    Args:
        resume: when false a fresh encoder is created; when true the
            existing module-level ``inverse`` is trained further.  A new
            Adam optimizer is created on every call.
    """
    global inverse

    device = torch.device('cuda:1')

    if not resume:
        inverse = PrivateEncoder(Xray(True)).cuda(device)

    # More adam to train.
    optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001)

    # Pair every encoded image (patches put in the recovered order) with the
    # original image it is currently matched to.
    this_xs = ooriginal[global_matching]
    this_ys = encoded[:, global_permutation, :]

    for i in range(2000):
        # randint's upper bound is exclusive, so this draws from the same
        # range as the deprecated random_integers(0, len(this_xs) - 1, 32).
        idx = np.random.randint(0, len(this_xs), 32)
        xbatch = torch.tensor(this_xs[idx]).cuda(device)
        ybatch = torch.tensor(this_ys[idx]).cuda(device)

        optimizer.zero_grad()

        guess_output = inverse(xbatch)
        # L1 loss because we don't want to be sensitive to outliers
        error = torch.mean(torch.abs(guess_output - ybatch))
        error.backward()

        optimizer.step()

    print(error)
|
||||
|
||||
|
||||
|
||||
def test_extract():
    """
    Now we can recover the matching much better by computing the estimated
    encodings for each original image.

    Runs every original image through the trained ``inverse`` network,
    compares each estimate with every (patch-reordered) challenge encoding
    by squared L2 distance, and solves one assignment problem.

    Side effects:
        Sets the module-level ``guessed_encoded``, ``smatrix`` and
        ``global_matching``.
    """
    global err, global_matching, guessed_encoded, smatrix

    device = torch.device('cuda:1')

    print(ooriginal.shape, encoded.shape)

    out = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, len(ooriginal), 32):
            print(i)
            batch = torch.tensor(ooriginal[i:i + 32]).cuda(device)
            out.extend(inverse(batch).cpu().detach().numpy())

    guessed_encoded = np.array(out)

    # Now we have to compare each encoded image with every other original image.
    # Do this fast with some matrix multiplies.

    out = guessed_encoded.reshape((len(encoded), -1))
    real = encoded[:, global_permutation, :].reshape((len(encoded), -1))

    @jax.jit
    def foo(x, y):
        # Pairwise squared L2 distance between rows of x and rows of y.
        return jn.square(x[:, None] - y[None, :]).sum(2)

    # smatrix[a, b] = distance between estimate a and challenge encoding b,
    # filled in B x B tiles.
    smatrix = np.zeros((len(out), len(out)))

    B = 500
    for i in range(0, len(out), B):
        print(i)
        for j in range(0, len(out), B):
            smatrix[i:i + B, j:j + B] = foo(out[i:i + B], real[j:j + B])

    # And the final time you'll have to look at a min weight matching, I promise.
    row, col = scipy.optimize.linear_sum_assignment(smatrix)

    print(list(row)[::100])

    # col maps original -> encoded; invert it to index by encoded image.
    updated = np.argsort(col)
    print("Differences", np.mean(updated != global_matching))

    global_matching = updated
|
||||
|
||||
|
||||
def perf(steps=[]):
    """Print elapsed-time checkpoints, then pause for one second.

    Args:
        steps: list of timestamps.  The shared default list is intentional:
            it persists between calls and so acts as the timing history.
            Each call appends the current time; every call after the first
            also prints the time since the previous call and since the
            first call.
    """
    if steps:
        print("Last Time Elapsed:", time.time()-steps[-1], ' Total Time Elapsed:', time.time()-steps[0])
    steps.append(time.time())
    time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Every stage is followed by a perf() call that reports its duration.
    perf()

    # Phase 1: train the patch similarity function.
    training_stages = (
        lambda: setup('challenge'),
        gen_train_data,
        get_train_features,
        train_model,
    )
    for stage in training_stages:
        stage()
        perf()

    # Phase 2: attack the challenge data.
    attack_stages = (
        load_challenge,
        recover_local_permutation,
        recover_better_local_permutation,
        compute_patch_similarity,
        recover_global_matching_first,
    )
    for stage in attack_stages:
        stage()
        perf()

    # Alternate between refining the patch permutation and the image matching.
    for _ in range(3):
        for stage in (recover_global_permutation, recover_global_matching_second):
            stage()
            perf()

    # Alternate between fitting the encoder and re-matching with it.  The
    # encoder is created on the first round and trained further afterwards.
    for round_idx in range(3):
        recover_global_permutation()
        perf()
        extract_by_training(round_idx > 0)
        perf()
        test_extract()
        perf()

    print(perf())
|
Loading…
Reference in a new issue