# Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # This program solves the NeuraCrypt challenge to 100% accuracy. # Given a set of encoded images and original versions of those, # it shows how to match the original to the encoded. import collections import hashlib import time import multiprocessing as mp import torch import numpy as np import torch.nn as nn import scipy.stats import matplotlib.pyplot as plt from PIL import Image import jax import jax.numpy as jn import objax import scipy.optimize import numpy as np import multiprocessing as mp # Objax neural network that's going to embed patches to a # low dimensional space to guess if two patches correspond # to the same orginal image. class Model(objax.Module): def __init__(self): IN = 15 H = 64 self.encoder =objax.nn.Sequential([ objax.nn.Linear(IN, H), objax.functional.leaky_relu, objax.nn.Linear(H, H), objax.functional.leaky_relu, objax.nn.Linear(H, 8)]) self.decoder =objax.nn.Sequential([ objax.nn.Linear(IN, H), objax.functional.leaky_relu, objax.nn.Linear(H, H), objax.functional.leaky_relu, objax.nn.Linear(H, 8)]) self.scale = objax.nn.Linear(1, 1, use_bias=False) def encode(self, x): # Encode turns original images into feature space a = self.encoder(x) a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5 return a def decode(self, x): # And decode turns encoded images into feature space a = self.decoder(x) a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5 return a # Proxy dataset for analysis class ImageNet: num_chan = 3 private_kernel_size = 16 hidden_dim = 2048 img_size = (256, 256) private_depth = 7 def __init__(self, remove): self.remove_pixel_shuffle = remove # Original dataset as used in the NeuraCrypt paper class Xray: num_chan = 1 private_kernel_size = 16 hidden_dim = 2048 img_size = (256, 256) private_depth = 4 def __init__(self, remove): self.remove_pixel_shuffle = remove ## The following class is taken directly from the NeuraCrypt codebase. ## https://github.com/yala/NeuraCrypt ## which is originally licensed under the MIT License class PrivateEncoder(nn.Module): def __init__(self, args, width_factor=1): super(PrivateEncoder, self).__init__() self.args = args input_dim = args.num_chan patch_size = args.private_kernel_size output_dim = args.hidden_dim num_patches = (args.img_size[0] // patch_size) **2 self.noise_size = 1 args.input_dim = args.hidden_dim layers = [ nn.Conv2d(input_dim, output_dim * width_factor, kernel_size=patch_size, dilation=1 ,stride=patch_size), nn.ReLU() ] for _ in range(self.args.private_depth): layers.extend( [ nn.Conv2d(output_dim * width_factor, output_dim * width_factor , kernel_size=1, dilation=1, stride=1), nn.BatchNorm2d(output_dim * width_factor, track_running_stats=False), nn.ReLU() ]) self.image_encoder = nn.Sequential(*layers) self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, output_dim * width_factor)) self.mixer = nn.Sequential( *[ nn.ReLU(), nn.Linear(output_dim * width_factor, output_dim) ]) def forward(self, x): encoded = self.image_encoder(x) B, C, H,W = encoded.size() encoded = encoded.view([B, -1, H*W]).transpose(1,2) encoded += self.pos_embedding encoded = self.mixer(encoded) ## Shuffle indicies if not self.args.remove_pixel_shuffle: shuffled = torch.zeros_like(encoded) for i in range(B): idx = torch.randperm(H*W, device=encoded.device) for j, k in enumerate(idx): shuffled[i,j] = encoded[i,k] encoded = shuffled return encoded ## End copied code def setup(ds): """ Load the datasets to use. Nothing interesting to see. """ global x_train, y_train if ds == 'imagenet': import torchvision transform = torchvision.transforms.Compose([ torchvision.transforms.Resize(256), torchvision.transforms.CenterCrop(256), torchvision.transforms.ToTensor()]) imagenet_data = torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/', split='val', transform=transform) data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=100, shuffle=True, num_workers=8) r = [] for x,_ in data_loader: if len(r) > 1000: break print(x.shape) r.extend(x.numpy()) x_train = np.array(r) print(x_train.shape) elif ds == 'xray': import torchvision transform = torchvision.transforms.Compose([ torchvision.transforms.Resize(256), torchvision.transforms.CenterCrop(256), torchvision.transforms.ToTensor()]) imagenet_data = torchvision.datasets.ImageFolder('CheXpert-v1.0/train', transform=transform) data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=100, shuffle=True, num_workers=8) r = [] for x,_ in data_loader: if len(r) > 1000: break print(x.shape) r.extend(x.numpy()) x_train = np.array(r) print(x_train.shape) elif ds == 'challenge': x_train = np.load("orig-7.npy") print(np.min(x_train), np.max(x_train), x_train.shape) else: raise def gen_train_data(): """ Generate aligned training data to train a patch similarity function. Given some original images, generate lots of encoded versions. """ global encoded_train, original_train encoded_train = [] original_train = [] args = Xray(True) C = 100 for i in range(30): print(i) torch.manual_seed(int(time.time())) e = PrivateEncoder(args).cuda() batch = np.random.randint(0, len(x_train), size=C) xin = x_train[batch] r = [] for i in range(0,C,32): r.extend(e(torch.tensor(xin[i:i+32]).cuda()).detach().cpu().numpy()) r = np.array(r) encoded_train.append(r) original_train.append(xin) def features_(x, moments=15, encoded=False): """ Compute higher-order moments for patches in an image to use as features for the neural network. """ x = np.array(x, dtype=np.float32) dim = 2 arr = np.array([np.mean(x, dim)] + [abs(scipy.stats.moment(x, moment=i, axis=dim))**(1/i) for i in range(1,moments)]) return arr.transpose((1,2,0)) def features(x, encoded): """ Given the original images or the encoded images, generate the features to use for the patch similarity function. """ print('start shape',x.shape) if len(x.shape) == 3: x = x - np.mean(x,axis=0,keepdims=True) else: # count x 100 x 256 x 768 print(x[0].shape) x = x - np.mean(x,axis=1,keepdims=True) # remove per-neural-network dimension x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:]) p = mp.Pool(96) B = len(x)//96 print(1) bs = [x[i:i+B] for i in range(0,len(x),B)] print(2) r = p.map(features_, bs) #r = features_(bs[0][:100]) print(3) p.close() #r = np.array(r) #print('finish',r.shape) return np.concatenate(r, axis=0) def get_train_features(): """ Create features for the entire datasets. """ global xs_train, ys_train print(x_train.shape) original_train_ = np.array(original_train) encoded_train_ = np.array(encoded_train) print("Computing features") ys_train = features(encoded_train_, True) patch_size = 16 ss = original_train_.shape[3]//patch_size # Okay so this is an ugly transpose block. # We are going from [outer_batch, batch_size, channels, width, height # to [outer_batch, batch_size, channels, width/patch_size, patch_size, height/patch_size, patch_size] # Then we reshape this and flatten so that we end up with # [other_batch, batch_size, width/patch_size, height_patch_size, patch_size**2*channels] # So that now we can run features on the last dimension original_train_ = original_train_.reshape((original_train_.shape[0], original_train_.shape[1], original_train_.shape[2], ss,patch_size,ss,patch_size)).transpose((0,1,3,5,2,4,6)).reshape((original_train_.shape[0], original_train_.shape[1], ss**2, patch_size**2)) xs_train = features(original_train_, False) print(xs_train.shape, ys_train.shape) def train_model(): """ Train the patch similarity function """ global ema, model model = Model() def loss(x, y): """ K-way contrastive loss as in SimCLR et al. The idea is that we should embed x and y so that they are similar to each other, and dis-similar from others. To do this we have a softmx loss over one dimension to make the values large on the diagonal and small off-diagonal. """ a = model.encode(x) b = model.decode(y) mat = a@b.T return objax.functional.loss.cross_entropy_logits_sparse( logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat, labels=np.arange(a.shape[0])).mean() ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999) gv = objax.GradValues(loss, model.vars()) encode_ema = ema.replace_vars(lambda x: model.encode(x)) decode_ema = ema.replace_vars(lambda y: model.decode(y)) def train_op(x, y): """ No one was ever fired for using Adam with 1e-4. """ g, v = gv(x, y) opt(1e-4, g) ema() return v opt = objax.optimizer.Adam(model.vars()) train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars()) ys_ = ys_train print(ys_.shape) xs_ = xs_train.reshape((-1, xs_train.shape[-1])) ys_ = ys_.reshape((-1, ys_train.shape[-1])) # The model scale trick here is taken from CLIP. # Let the model decide how confident to make its own predictions. model.scale.w.assign(jn.zeros((1,1))) valid_size = 1000 print(xs_train.shape) # SimCLR likes big batches B = 4096 for it in range(80): print() ms = [] for i in range(1000): # First batch is smaller, to make training more stable bs = [B//64, B][it>0] batch = np.random.randint(0, len(xs_)-valid_size, size=bs) r = train_op(xs_[batch], ys_[batch]) # This shouldn't happen, but if it does, better to bort early if np.isnan(r): print("Die on nan") print(ms[-100:]) return ms.append(r) print('mean',np.mean(ms), 'scale', model.scale.w.value) print('loss',loss(xs_[-100:], ys_[-100:])) a = encode_ema(xs_[-valid_size:]) b = decode_ema(ys_[-valid_size:]) br = b[np.random.permutation(len(b))] print('score',np.mean(np.sum(a*b,axis=(1)) - np.sum(a*br,axis=(1))), np.mean(np.sum(a*b,axis=(1)) > np.sum(a*br,axis=(1)))) ckpt = objax.io.Checkpoint("saved", keep_ckpts=0) ema.replace_vars(lambda: ckpt.save(model.vars(), 0))() def load_challenge(): """ Load the challenge datast for attacking """ global xs, ys, encoded, original, ooriginal print("SETUP: Loading matrixes") # The encoded images encoded = np.load("challenge-7.npy") # And the original images ooriginal = original = np.load("orig-7.npy") print("Sizes", encoded.shape, ooriginal.shape) # Again do that ugly resize thing to make the features be on the last dimension # Look up above to see what's going on. patch_size = 16 ss = original.shape[2]//patch_size original = ooriginal.reshape((original.shape[0],1,ss,patch_size,ss,patch_size)) original = original.transpose((0,2,4,1,3,5)) original = original.reshape((original.shape[0], ss**2, patch_size**2)) def match_sub(args): """ Find the best way to undo the permutation between two images. """ vec1, vec2 = args value = np.sum((vec1[None,:,:] - vec2[:,None,:])**2,axis=2) row, col = scipy.optimize.linear_sum_assignment(value) return col def recover_local_permutation(): """ Given a set of encoded images, return a new encoding without permutations """ global encoded, ys p = mp.Pool(96) print('recover local') local_perm = p.map(match_sub, [(encoded[0], e) for e in encoded]) local_perm = np.array(local_perm) encoded_perm = [] for i in range(len(encoded)): encoded_perm.append(encoded[i][np.argsort(local_perm[i])]) encoded_perm = np.array(encoded_perm) encoded = np.array(encoded_perm) p.close() def recover_better_local_permutation(): """ Given a set of encoded images, return a new encoding, but better! """ global encoded, ys # Now instead of pairing all images to image 0, we compute the mean l2 vector # and then pair all images onto the mean vector. Slightly more noise resistant. p = mp.Pool(96) target = encoded.mean(0) local_perm = p.map(match_sub, [(target, e) for e in encoded]) local_perm = np.array(local_perm) # Probably we didn't change by much, generally <0.1% print('improved changed by', np.mean(local_perm != np.arange(local_perm.shape[1]))) encoded_perm = [] for i in range(len(encoded)): encoded_perm.append(encoded[i][np.argsort(local_perm[i])]) encoded = np.array(encoded_perm) p.close() def compute_patch_similarity(): """ Compute the feature vectors for each patch using the trained neural network. """ global xs, ys, xs_image, ys_image print("Computing features") ys = features(encoded, encoded=True) xs = features(original, encoded=False) model = Model() ckpt = objax.io.Checkpoint("saved", keep_ckpts=0) ckpt.restore(model.vars()) xs_image = model.encode(xs) ys_image = model.decode(ys) assert xs.shape[0] == xs_image.shape[0] print("Done") def match(args, ret_col=False): """ Compute the similarity between image features and encoded features. """ vec1, vec2s = args r = [] open("/tmp/start%d.%d"%(np.random.randint(10000),time.time()),"w").write("hi") for vec2 in vec2s: value = np.sum(vec1[None,:,:] * vec2[:,None,:],axis=2) row, col = scipy.optimize.linear_sum_assignment(-value) r.append(value[row,col].mean()) return r def recover_global_matching_first(): """ Recover the global matching of original to encoded images by doing an all-pairs matching problem """ global global_matching, ys_image, encoded matrix = [] p = mp.Pool(96) xs_image_ = np.array(xs_image) ys_image_ = np.array(ys_image) matrix = p.map(match, [(x, ys_image_) for x in xs_image_]) matrix = np.array(matrix).reshape((xs_image.shape[0], xs_image.shape[0])) row, col = scipy.optimize.linear_sum_assignment(-np.array(matrix)) global_matching = np.argsort(col) print('glob',list(global_matching)) p.close() def recover_global_permutation(): """ Find the way that the encoded images are permuted off of the original images """ global global_permutation print("Glob match", global_matching) overall = [] for i,j in enumerate(global_matching): overall.append(np.sum(xs_image[j][None,:,:] * ys_image[i][:,None,:],axis=2)) overall = np.mean(overall, 0) row, col = scipy.optimize.linear_sum_assignment(-overall) try: print("Changed frac:", np.mean(global_permutation!=np.argsort(col))) except: pass global_permutation = np.argsort(col) def recover_global_matching_second(): """ Match each encoded image with its original encoded image, but better by relying on the global permutation. """ global global_matching_second, global_matching ys_fix = [] for i in range(ys_image.shape[0]): ys_fix.append(ys_image[i][global_permutation]) ys_fix = np.array(ys_fix) print(xs_image.shape) sims = [] for i in range(0,len(xs_image),10): tmp = np.mean(xs_image[None,:,:,:] * ys_fix[i:i+10][:,None,:,:],axis=(2,3)) sims.extend(tmp) sims = np.array(sims) print(sims.shape) row, col = scipy.optimize.linear_sum_assignment(-sims) print('arg',sims.argmax(1)) print("Same matching frac", np.mean(col == global_matching) ) print(col) global_matching = col def extract_by_training(resume): """ Final recovery process by extracting the neural network """ global inverse device = torch.device('cuda:1') if not resume: inverse = PrivateEncoder(Xray(True)).cuda(device) # More adam to train. optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001) this_xs = ooriginal[global_matching] this_ys = encoded[:,global_permutation,:] for i in range(2000): idx = np.random.random_integers(0, len(this_xs)-1, 32) xbatch = torch.tensor(this_xs[idx]).cuda(device) ybatch = torch.tensor(this_ys[idx]).cuda(device) optimizer.zero_grad() guess_output = inverse(xbatch) # L1 loss because we don't want to be sensitive to outliers error = torch.mean(torch.abs(guess_output-ybatch)) error.backward() optimizer.step() print(error) def test_extract(): """ Now we can recover the matching much better by computing the estimated encodings for each original image. """ global err, global_matching, guessed_encoded, smatrix device = torch.device('cuda:1') print(ooriginal.shape, encoded.shape) out = [] for i in range(0,len(ooriginal),32): print(i) out.extend(inverse(torch.tensor(ooriginal[i:i+32]).cuda(device)).cpu().detach().numpy()) guessed_encoded = np.array(out) # Now we have to compare each encoded image with every other original image. # Do this fast with some matrix multiplies. out = guessed_encoded.reshape((len(encoded), -1)) real = encoded[:,global_permutation,:].reshape((len(encoded), -1)) @jax.jit def foo(x, y): return jn.square(x[:,None] - y[None,:]).sum(2) smatrix = np.zeros((len(out), len(out))) B = 500 for i in range(0,len(out),B): print(i) for j in range(0,len(out),B): smatrix[i:i+B, j:j+B] = foo(out[i:i+B], real[j:j+B]) # And the final time you'l have to look at a min weight matching, I promise. row, col = scipy.optimize.linear_sum_assignment(np.array(smatrix)) r = np.array(smatrix) print(list(row)[::100]) print("Differences", np.mean(np.argsort(col) != global_matching)) global_matching = np.argsort(col) def perf(steps=[]): if len(steps) == 0: steps.append(time.time()) else: print("Last Time Elapsed:", time.time()-steps[-1], ' Total Time Elapsed:', time.time()-steps[0]) steps.append(time.time()) time.sleep(1) if __name__ == "__main__": if True: perf() setup('challenge') perf() gen_train_data() perf() get_train_features() perf() train_model() perf() if True: load_challenge() perf() recover_local_permutation() perf() recover_better_local_permutation() perf() compute_patch_similarity() perf() recover_global_matching_first() perf() for _ in range(3): recover_global_permutation() perf() recover_global_matching_second() perf() for i in range(3): recover_global_permutation() perf() extract_by_training(i > 0) perf() test_extract() perf() print(perf())