diff --git a/research/neuracrypt_attack_2021/attack.py b/research/neuracrypt_attack_2021/attack.py new file mode 100644 index 0000000..dc2e91a --- /dev/null +++ b/research/neuracrypt_attack_2021/attack.py @@ -0,0 +1,712 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This program solves the NeuraCrypt challenge to 100% accuracy. +# Given a set of encoded images and original versions of those, +# it shows how to match the original to the encoded. + +import collections +import hashlib +import time +import multiprocessing as mp + + +import torch +import numpy as np +import torch.nn as nn +import scipy.stats +import matplotlib.pyplot as plt +from PIL import Image + +import jax +import jax.numpy as jn +import objax +import scipy.optimize +import numpy as np +import multiprocessing as mp + +# Objax neural network that's going to embed patches to a +# low dimensional space to guess if two patches correspond +# to the same orginal image. +class Model(objax.Module): + def __init__(self): + IN = 15 + H = 64 + self.encoder =objax.nn.Sequential([ + objax.nn.Linear(IN, H), + objax.functional.leaky_relu, + objax.nn.Linear(H, H), + objax.functional.leaky_relu, + objax.nn.Linear(H, 8)]) + self.decoder =objax.nn.Sequential([ + objax.nn.Linear(IN, H), + objax.functional.leaky_relu, + objax.nn.Linear(H, H), + objax.functional.leaky_relu, + objax.nn.Linear(H, 8)]) + self.scale = objax.nn.Linear(1, 1, use_bias=False) + def encode(self, x): + # Encode turns original images into feature space + a = self.encoder(x) + a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5 + return a + def decode(self, x): + # And decode turns encoded images into feature space + a = self.decoder(x) + a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5 + return a + +# Proxy dataset for analysis +class ImageNet: + num_chan = 3 + private_kernel_size = 16 + hidden_dim = 2048 + img_size = (256, 256) + private_depth = 7 + def __init__(self, remove): + self.remove_pixel_shuffle = remove + +# Original dataset as used in the NeuraCrypt paper +class Xray: + num_chan = 1 + private_kernel_size = 16 + hidden_dim = 2048 + img_size = (256, 256) + private_depth = 4 + def __init__(self, remove): + self.remove_pixel_shuffle = remove + +## The following class is taken directly from the NeuraCrypt codebase. +## https://github.com/yala/NeuraCrypt +## which is originally licensed under the MIT License +class PrivateEncoder(nn.Module): + def __init__(self, args, width_factor=1): + super(PrivateEncoder, self).__init__() + self.args = args + input_dim = args.num_chan + patch_size = args.private_kernel_size + output_dim = args.hidden_dim + num_patches = (args.img_size[0] // patch_size) **2 + self.noise_size = 1 + + args.input_dim = args.hidden_dim + + + layers = [ + nn.Conv2d(input_dim, output_dim * width_factor, kernel_size=patch_size, dilation=1 ,stride=patch_size), + nn.ReLU() + ] + for _ in range(self.args.private_depth): + layers.extend( [ + nn.Conv2d(output_dim * width_factor, output_dim * width_factor , kernel_size=1, dilation=1, stride=1), + nn.BatchNorm2d(output_dim * width_factor, track_running_stats=False), + nn.ReLU() + ]) + + + self.image_encoder = nn.Sequential(*layers) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, output_dim * width_factor)) + + self.mixer = nn.Sequential( *[ + nn.ReLU(), + nn.Linear(output_dim * width_factor, output_dim) + ]) + + + def forward(self, x): + encoded = self.image_encoder(x) + B, C, H,W = encoded.size() + encoded = encoded.view([B, -1, H*W]).transpose(1,2) + encoded += self.pos_embedding + encoded = self.mixer(encoded) + + ## Shuffle indicies + if not self.args.remove_pixel_shuffle: + shuffled = torch.zeros_like(encoded) + for i in range(B): + idx = torch.randperm(H*W, device=encoded.device) + for j, k in enumerate(idx): + shuffled[i,j] = encoded[i,k] + encoded = shuffled + + return encoded +## End copied code + +def setup(ds): + """ + Load the datasets to use. Nothing interesting to see. + """ + global x_train, y_train + if ds == 'imagenet': + import torchvision + transform = torchvision.transforms.Compose([ + torchvision.transforms.Resize(256), + torchvision.transforms.CenterCrop(256), + torchvision.transforms.ToTensor()]) + imagenet_data = torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/', + split='val', + transform=transform) + data_loader = torch.utils.data.DataLoader(imagenet_data, + batch_size=100, + shuffle=True, + num_workers=8) + r = [] + for x,_ in data_loader: + if len(r) > 1000: break + print(x.shape) + r.extend(x.numpy()) + x_train = np.array(r) + print(x_train.shape) + elif ds == 'xray': + import torchvision + transform = torchvision.transforms.Compose([ + torchvision.transforms.Resize(256), + torchvision.transforms.CenterCrop(256), + torchvision.transforms.ToTensor()]) + imagenet_data = torchvision.datasets.ImageFolder('CheXpert-v1.0/train', + transform=transform) + data_loader = torch.utils.data.DataLoader(imagenet_data, + batch_size=100, + shuffle=True, + num_workers=8) + r = [] + for x,_ in data_loader: + if len(r) > 1000: break + print(x.shape) + r.extend(x.numpy()) + x_train = np.array(r) + print(x_train.shape) + elif ds == 'challenge': + x_train = np.load("orig-7.npy") + print(np.min(x_train), np.max(x_train), x_train.shape) + else: + raise + + +def gen_train_data(): + """ + Generate aligned training data to train a patch similarity function. + Given some original images, generate lots of encoded versions. + """ + global encoded_train, original_train + + encoded_train = [] + original_train = [] + + args = Xray(True) + + C = 100 + for i in range(30): + print(i) + torch.manual_seed(int(time.time())) + e = PrivateEncoder(args).cuda() + batch = np.random.randint(0, len(x_train), size=C) + xin = x_train[batch] + + r = [] + for i in range(0,C,32): + r.extend(e(torch.tensor(xin[i:i+32]).cuda()).detach().cpu().numpy()) + r = np.array(r) + + encoded_train.append(r) + original_train.append(xin) + +def features_(x, moments=15, encoded=False): + """ + Compute higher-order moments for patches in an image to use as + features for the neural network. + """ + x = np.array(x, dtype=np.float32) + dim = 2 + arr = np.array([np.mean(x, dim)] + [abs(scipy.stats.moment(x, moment=i, axis=dim))**(1/i) for i in range(1,moments)]) + + return arr.transpose((1,2,0)) + + +def features(x, encoded): + """ + Given the original images or the encoded images, generate the + features to use for the patch similarity function. + """ + print('start shape',x.shape) + if len(x.shape) == 3: + x = x - np.mean(x,axis=0,keepdims=True) + else: + # count x 100 x 256 x 768 + print(x[0].shape) + x = x - np.mean(x,axis=1,keepdims=True) + # remove per-neural-network dimension + x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:]) + p = mp.Pool(96) + B = len(x) // 96 + print(1) + bs = [x[i:i+B] for i in range(0,len(x),B)] + print(2) + r = p.map(features_, bs) + #r = features_(bs[0][:100]) + print(3) + p.close() + #r = np.array(r) + #print('finish',r.shape) + return np.concatenate(r, axis=0) + + + +def get_train_features(): + """ + Create features for the entire datasets. + """ + global xs_train, ys_train + print(x_train.shape) + original_train_ = np.array(original_train) + encoded_train_ = np.array(encoded_train) + + print("Computing features") + ys_train = features(encoded_train_, True) + + patch_size = 16 + ss = original_train_.shape[3] // patch_size + # Okay so this is an ugly transpose block. + # We are going from [outer_batch, batch_size, channels, width, height + # to [outer_batch, batch_size, channels, width/patch_size, patch_size, height/patch_size, patch_size] + # Then we reshape this and flatten so that we end up with + # [other_batch, batch_size, width/patch_size, height_patch_size, patch_size**2*channels] + # So that now we can run features on the last dimension + original_train_ = original_train_.reshape((original_train_.shape[0], + original_train_.shape[1], + original_train_.shape[2], + ss,patch_size,ss,patch_size)).transpose((0,1,3,5,2,4,6)).reshape((original_train_.shape[0], original_train_.shape[1], ss**2, patch_size**2)) + + + xs_train = features(original_train_, False) + + print(xs_train.shape, ys_train.shape) + + +def train_model(): + """ + Train the patch similarity function + """ + global ema, model + + model = Model() + def loss(x, y): + """ + K-way contrastive loss as in SimCLR et al. + The idea is that we should embed x and y so that they are similar + to each other, and dis-similar from others. To do this we have a + softmx loss over one dimension to make the values large on the diagonal + and small off-diagonal. + """ + a = model.encode(x) + b = model.decode(y) + + mat = a@b.T + return objax.functional.loss.cross_entropy_logits_sparse( + logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat, + labels=np.arange(a.shape[0])).mean() + + ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999) + gv = objax.GradValues(loss, model.vars()) + + encode_ema = ema.replace_vars(lambda x: model.encode(x)) + decode_ema = ema.replace_vars(lambda y: model.decode(y)) + + def train_op(x, y): + """ + No one was ever fired for using Adam with 1e-4. + """ + g, v = gv(x, y) + opt(1e-4, g) + ema() + return v + + opt = objax.optimizer.Adam(model.vars()) + train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars()) + + ys_ = ys_train + + print(ys_.shape) + + xs_ = xs_train.reshape((-1, xs_train.shape[-1])) + ys_ = ys_.reshape((-1, ys_train.shape[-1])) + + # The model scale trick here is taken from CLIP. + # Let the model decide how confident to make its own predictions. + model.scale.w.assign(jn.zeros((1,1))) + + valid_size = 1000 + + print(xs_train.shape) + # SimCLR likes big batches + B = 4096 + for it in range(80): + print() + ms = [] + for i in range(1000): + # First batch is smaller, to make training more stable + bs = [B // 64, B][it>0] + batch = np.random.randint(0, len(xs_)-valid_size, size=bs) + r = train_op(xs_[batch], ys_[batch]) + + # This shouldn't happen, but if it does, better to bort early + if np.isnan(r): + print("Die on nan") + print(ms[-100:]) + return + ms.append(r) + + print('mean',np.mean(ms), 'scale', model.scale.w.value) + print('loss',loss(xs_[-100:], ys_[-100:])) + + a = encode_ema(xs_[-valid_size:]) + b = decode_ema(ys_[-valid_size:]) + + br = b[np.random.permutation(len(b))] + + print('score',np.mean(np.sum(a*b,axis=(1)) - np.sum(a*br,axis=(1))), + np.mean(np.sum(a*b,axis=(1)) > np.sum(a*br,axis=(1)))) + ckpt = objax.io.Checkpoint("saved", keep_ckpts=0) + ema.replace_vars(lambda: ckpt.save(model.vars(), 0))() + + + +def load_challenge(): + """ + Load the challenge datast for attacking + """ + global xs, ys, encoded, original, ooriginal + print("SETUP: Loading matrixes") + # The encoded images + encoded = np.load("challenge-7.npy") + # And the original images + ooriginal = original = np.load("orig-7.npy") + + print("Sizes", encoded.shape, ooriginal.shape) + + # Again do that ugly resize thing to make the features be on the last dimension + # Look up above to see what's going on. + patch_size = 16 + ss = original.shape[2] // patch_size + original = ooriginal.reshape((original.shape[0],1,ss,patch_size,ss,patch_size)) + original = original.transpose((0,2,4,1,3,5)) + original = original.reshape((original.shape[0], ss**2, patch_size**2)) + + +def match_sub(args): + """ + Find the best way to undo the permutation between two images. + """ + vec1, vec2 = args + value = np.sum((vec1[None,:,:] - vec2[:,None,:])**2,axis=2) + row, col = scipy.optimize.linear_sum_assignment(value) + return col + + +def recover_local_permutation(): + """ + Given a set of encoded images, return a new encoding without permutations + """ + global encoded, ys + + p = mp.Pool(96) + print('recover local') + local_perm = p.map(match_sub, [(encoded[0], e) for e in encoded]) + local_perm = np.array(local_perm) + + encoded_perm = [] + + for i in range(len(encoded)): + encoded_perm.append(encoded[i][np.argsort(local_perm[i])]) + + encoded_perm = np.array(encoded_perm) + + encoded = np.array(encoded_perm) + + p.close() + + +def recover_better_local_permutation(): + """ + Given a set of encoded images, return a new encoding, but better! + """ + global encoded, ys + + # Now instead of pairing all images to image 0, we compute the mean l2 vector + # and then pair all images onto the mean vector. Slightly more noise resistant. + p = mp.Pool(96) + target = encoded.mean(0) + local_perm = p.map(match_sub, [(target, e) for e in encoded]) + local_perm = np.array(local_perm) + + # Probably we didn't change by much, generally <0.1% + print('improved changed by', np.mean(local_perm != np.arange(local_perm.shape[1]))) + + encoded_perm = [] + + for i in range(len(encoded)): + encoded_perm.append(encoded[i][np.argsort(local_perm[i])]) + + encoded = np.array(encoded_perm) + + p.close() + + +def compute_patch_similarity(): + """ + Compute the feature vectors for each patch using the trained neural network. + """ + global xs, ys, xs_image, ys_image + + print("Computing features") + ys = features(encoded, encoded=True) + xs = features(original, encoded=False) + + model = Model() + ckpt = objax.io.Checkpoint("saved", keep_ckpts=0) + ckpt.restore(model.vars()) + + xs_image = model.encode(xs) + ys_image = model.decode(ys) + assert xs.shape[0] == xs_image.shape[0] + print("Done") + + +def match(args, ret_col=False): + """ + Compute the similarity between image features and encoded features. + """ + vec1, vec2s = args + r = [] + open("/tmp/start%d.%d"%(np.random.randint(10000),time.time()),"w").write("hi") + for vec2 in vec2s: + value = np.sum(vec1[None,:,:] * vec2[:,None,:],axis=2) + + row, col = scipy.optimize.linear_sum_assignment(-value) + r.append(value[row,col].mean()) + return r + + + +def recover_global_matching_first(): + """ + Recover the global matching of original to encoded images by doing + an all-pairs matching problem + """ + global global_matching, ys_image, encoded + + matrix = [] + p = mp.Pool(96) + xs_image_ = np.array(xs_image) + ys_image_ = np.array(ys_image) + + matrix = p.map(match, [(x, ys_image_) for x in xs_image_]) + matrix = np.array(matrix).reshape((xs_image.shape[0], + xs_image.shape[0])) + + + row, col = scipy.optimize.linear_sum_assignment(-np.array(matrix)) + global_matching = np.argsort(col) + print('glob',list(global_matching)) + + p.close() + + + +def recover_global_permutation(): + """ + Find the way that the encoded images are permuted off of the original images + """ + global global_permutation + + print("Glob match", global_matching) + overall = [] + for i,j in enumerate(global_matching): + overall.append(np.sum(xs_image[j][None,:,:] * ys_image[i][:,None,:],axis=2)) + + overall = np.mean(overall, 0) + + row, col = scipy.optimize.linear_sum_assignment(-overall) + + try: + print("Changed frac:", np.mean(global_permutation!=np.argsort(col))) + except: + pass + + global_permutation = np.argsort(col) + + +def recover_global_matching_second(): + """ + Match each encoded image with its original encoded image, + but better by relying on the global permutation. + """ + global global_matching_second, global_matching + + ys_fix = [] + for i in range(ys_image.shape[0]): + ys_fix.append(ys_image[i][global_permutation]) + ys_fix = np.array(ys_fix) + + + print(xs_image.shape) + + sims = [] + for i in range(0,len(xs_image),10): + tmp = np.mean(xs_image[None,:,:,:] * ys_fix[i:i+10][:,None,:,:],axis=(2,3)) + sims.extend(tmp) + sims = np.array(sims) + print(sims.shape) + + + row, col = scipy.optimize.linear_sum_assignment(-sims) + + print('arg',sims.argmax(1)) + + print("Same matching frac", np.mean(col == global_matching) ) + print(col) + global_matching = col + + +def extract_by_training(resume): + """ + Final recovery process by extracting the neural network + """ + global inverse + + device = torch.device('cuda:1') + + if not resume: + inverse = PrivateEncoder(Xray(True)).cuda(device) + + # More adam to train. + optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001) + + this_xs = ooriginal[global_matching] + this_ys = encoded[:,global_permutation,:] + + for i in range(2000): + idx = np.random.random_integers(0, len(this_xs)-1, 32) + xbatch = torch.tensor(this_xs[idx]).cuda(device) + ybatch = torch.tensor(this_ys[idx]).cuda(device) + + optimizer.zero_grad() + + guess_output = inverse(xbatch) + # L1 loss because we don't want to be sensitive to outliers + error = torch.mean(torch.abs(guess_output-ybatch)) + error.backward() + + optimizer.step() + + print(error) + + + +def test_extract(): + """ + Now we can recover the matching much better by computing the estimated + encodings for each original image. + """ + global err, global_matching, guessed_encoded, smatrix + + device = torch.device('cuda:1') + + print(ooriginal.shape, encoded.shape) + + out = [] + for i in range(0,len(ooriginal),32): + print(i) + out.extend(inverse(torch.tensor(ooriginal[i:i+32]).cuda(device)).cpu().detach().numpy()) + + guessed_encoded = np.array(out) + + + # Now we have to compare each encoded image with every other original image. + # Do this fast with some matrix multiplies. + + out = guessed_encoded.reshape((len(encoded), -1)) + real = encoded[:,global_permutation,:].reshape((len(encoded), -1)) + @jax.jit + def foo(x, y): + return jn.square(x[:,None] - y[None,:]).sum(2) + + smatrix = np.zeros((len(out), len(out))) + + B = 500 + for i in range(0,len(out),B): + print(i) + for j in range(0,len(out),B): + smatrix[i:i+B, j:j+B] = foo(out[i:i+B], real[j:j+B]) + + # And the final time you'l have to look at a min weight matching, I promise. + row, col = scipy.optimize.linear_sum_assignment(np.array(smatrix)) + r = np.array(smatrix) + + print(list(row)[::100]) + + print("Differences", np.mean(np.argsort(col) != global_matching)) + + global_matching = np.argsort(col) + + +def perf(steps=[]): + if len(steps) == 0: + steps.append(time.time()) + else: + print("Last Time Elapsed:", time.time()-steps[-1], ' Total Time Elapsed:', time.time()-steps[0]) + steps.append(time.time()) + time.sleep(1) + + +if __name__ == "__main__": + if True: + perf() + setup('challenge') + perf() + gen_train_data() + perf() + get_train_features() + perf() + train_model() + perf() + + if True: + load_challenge() + perf() + recover_local_permutation() + perf() + recover_better_local_permutation() + perf() + compute_patch_similarity() + perf() + recover_global_matching_first() + perf() + + for _ in range(3): + recover_global_permutation() + perf() + recover_global_matching_second() + perf() + + for i in range(3): + recover_global_permutation() + perf() + extract_by_training(i > 0) + perf() + test_extract() + perf() + print(perf())