Merge pull request #185 from carlini:neuracrypt

PiperOrigin-RevId: 429632517
This commit is contained in:
Michael Reneer 2022-02-18 21:10:54 +00:00
commit 7e0b193393

View file

@ -0,0 +1,712 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This program solves the NeuraCrypt challenge to 100% accuracy.
# Given a set of encoded images and original versions of those,
# it shows how to match the original to the encoded.
import collections
import hashlib
import time
import multiprocessing as mp
import torch
import numpy as np
import torch.nn as nn
import scipy.stats
import matplotlib.pyplot as plt
from PIL import Image
import jax
import jax.numpy as jn
import objax
import scipy.optimize
import numpy as np
import multiprocessing as mp
# Objax neural network that's going to embed patches to a
# low dimensional space to guess if two patches correspond
# to the same orginal image.
class Model(objax.Module):
def __init__(self):
IN = 15
H = 64
self.encoder =objax.nn.Sequential([
objax.nn.Linear(IN, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, 8)])
self.decoder =objax.nn.Sequential([
objax.nn.Linear(IN, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, 8)])
self.scale = objax.nn.Linear(1, 1, use_bias=False)
def encode(self, x):
# Encode turns original images into feature space
a = self.encoder(x)
a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5
return a
def decode(self, x):
# And decode turns encoded images into feature space
a = self.decoder(x)
a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5
return a
# Proxy dataset for analysis
class ImageNet:
num_chan = 3
private_kernel_size = 16
hidden_dim = 2048
img_size = (256, 256)
private_depth = 7
def __init__(self, remove):
self.remove_pixel_shuffle = remove
# Original dataset as used in the NeuraCrypt paper
class Xray:
num_chan = 1
private_kernel_size = 16
hidden_dim = 2048
img_size = (256, 256)
private_depth = 4
def __init__(self, remove):
self.remove_pixel_shuffle = remove
## The following class is taken directly from the NeuraCrypt codebase.
## https://github.com/yala/NeuraCrypt
## which is originally licensed under the MIT License
class PrivateEncoder(nn.Module):
def __init__(self, args, width_factor=1):
super(PrivateEncoder, self).__init__()
self.args = args
input_dim = args.num_chan
patch_size = args.private_kernel_size
output_dim = args.hidden_dim
num_patches = (args.img_size[0] // patch_size) **2
self.noise_size = 1
args.input_dim = args.hidden_dim
layers = [
nn.Conv2d(input_dim, output_dim * width_factor, kernel_size=patch_size, dilation=1 ,stride=patch_size),
nn.ReLU()
]
for _ in range(self.args.private_depth):
layers.extend( [
nn.Conv2d(output_dim * width_factor, output_dim * width_factor , kernel_size=1, dilation=1, stride=1),
nn.BatchNorm2d(output_dim * width_factor, track_running_stats=False),
nn.ReLU()
])
self.image_encoder = nn.Sequential(*layers)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, output_dim * width_factor))
self.mixer = nn.Sequential( *[
nn.ReLU(),
nn.Linear(output_dim * width_factor, output_dim)
])
def forward(self, x):
encoded = self.image_encoder(x)
B, C, H,W = encoded.size()
encoded = encoded.view([B, -1, H*W]).transpose(1,2)
encoded += self.pos_embedding
encoded = self.mixer(encoded)
## Shuffle indicies
if not self.args.remove_pixel_shuffle:
shuffled = torch.zeros_like(encoded)
for i in range(B):
idx = torch.randperm(H*W, device=encoded.device)
for j, k in enumerate(idx):
shuffled[i,j] = encoded[i,k]
encoded = shuffled
return encoded
## End copied code
def setup(ds):
"""
Load the datasets to use. Nothing interesting to see.
"""
global x_train, y_train
if ds == 'imagenet':
import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(256),
torchvision.transforms.ToTensor()])
imagenet_data = torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/',
split='val',
transform=transform)
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=100,
shuffle=True,
num_workers=8)
r = []
for x,_ in data_loader:
if len(r) > 1000: break
print(x.shape)
r.extend(x.numpy())
x_train = np.array(r)
print(x_train.shape)
elif ds == 'xray':
import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(256),
torchvision.transforms.ToTensor()])
imagenet_data = torchvision.datasets.ImageFolder('CheXpert-v1.0/train',
transform=transform)
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=100,
shuffle=True,
num_workers=8)
r = []
for x,_ in data_loader:
if len(r) > 1000: break
print(x.shape)
r.extend(x.numpy())
x_train = np.array(r)
print(x_train.shape)
elif ds == 'challenge':
x_train = np.load("orig-7.npy")
print(np.min(x_train), np.max(x_train), x_train.shape)
else:
raise
def gen_train_data():
"""
Generate aligned training data to train a patch similarity function.
Given some original images, generate lots of encoded versions.
"""
global encoded_train, original_train
encoded_train = []
original_train = []
args = Xray(True)
C = 100
for i in range(30):
print(i)
torch.manual_seed(int(time.time()))
e = PrivateEncoder(args).cuda()
batch = np.random.randint(0, len(x_train), size=C)
xin = x_train[batch]
r = []
for i in range(0,C,32):
r.extend(e(torch.tensor(xin[i:i+32]).cuda()).detach().cpu().numpy())
r = np.array(r)
encoded_train.append(r)
original_train.append(xin)
def features_(x, moments=15, encoded=False):
"""
Compute higher-order moments for patches in an image to use as
features for the neural network.
"""
x = np.array(x, dtype=np.float32)
dim = 2
arr = np.array([np.mean(x, dim)] + [abs(scipy.stats.moment(x, moment=i, axis=dim))**(1/i) for i in range(1,moments)])
return arr.transpose((1,2,0))
def features(x, encoded):
"""
Given the original images or the encoded images, generate the
features to use for the patch similarity function.
"""
print('start shape',x.shape)
if len(x.shape) == 3:
x = x - np.mean(x,axis=0,keepdims=True)
else:
# count x 100 x 256 x 768
print(x[0].shape)
x = x - np.mean(x,axis=1,keepdims=True)
# remove per-neural-network dimension
x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
p = mp.Pool(96)
B = len(x) // 96
print(1)
bs = [x[i:i+B] for i in range(0,len(x),B)]
print(2)
r = p.map(features_, bs)
#r = features_(bs[0][:100])
print(3)
p.close()
#r = np.array(r)
#print('finish',r.shape)
return np.concatenate(r, axis=0)
def get_train_features():
"""
Create features for the entire datasets.
"""
global xs_train, ys_train
print(x_train.shape)
original_train_ = np.array(original_train)
encoded_train_ = np.array(encoded_train)
print("Computing features")
ys_train = features(encoded_train_, True)
patch_size = 16
ss = original_train_.shape[3] // patch_size
# Okay so this is an ugly transpose block.
# We are going from [outer_batch, batch_size, channels, width, height
# to [outer_batch, batch_size, channels, width/patch_size, patch_size, height/patch_size, patch_size]
# Then we reshape this and flatten so that we end up with
# [other_batch, batch_size, width/patch_size, height_patch_size, patch_size**2*channels]
# So that now we can run features on the last dimension
original_train_ = original_train_.reshape((original_train_.shape[0],
original_train_.shape[1],
original_train_.shape[2],
ss,patch_size,ss,patch_size)).transpose((0,1,3,5,2,4,6)).reshape((original_train_.shape[0], original_train_.shape[1], ss**2, patch_size**2))
xs_train = features(original_train_, False)
print(xs_train.shape, ys_train.shape)
def train_model():
"""
Train the patch similarity function
"""
global ema, model
model = Model()
def loss(x, y):
"""
K-way contrastive loss as in SimCLR et al.
The idea is that we should embed x and y so that they are similar
to each other, and dis-similar from others. To do this we have a
softmx loss over one dimension to make the values large on the diagonal
and small off-diagonal.
"""
a = model.encode(x)
b = model.decode(y)
mat = a@b.T
return objax.functional.loss.cross_entropy_logits_sparse(
logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat,
labels=np.arange(a.shape[0])).mean()
ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)
gv = objax.GradValues(loss, model.vars())
encode_ema = ema.replace_vars(lambda x: model.encode(x))
decode_ema = ema.replace_vars(lambda y: model.decode(y))
def train_op(x, y):
"""
No one was ever fired for using Adam with 1e-4.
"""
g, v = gv(x, y)
opt(1e-4, g)
ema()
return v
opt = objax.optimizer.Adam(model.vars())
train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars())
ys_ = ys_train
print(ys_.shape)
xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
ys_ = ys_.reshape((-1, ys_train.shape[-1]))
# The model scale trick here is taken from CLIP.
# Let the model decide how confident to make its own predictions.
model.scale.w.assign(jn.zeros((1,1)))
valid_size = 1000
print(xs_train.shape)
# SimCLR likes big batches
B = 4096
for it in range(80):
print()
ms = []
for i in range(1000):
# First batch is smaller, to make training more stable
bs = [B // 64, B][it>0]
batch = np.random.randint(0, len(xs_)-valid_size, size=bs)
r = train_op(xs_[batch], ys_[batch])
# This shouldn't happen, but if it does, better to bort early
if np.isnan(r):
print("Die on nan")
print(ms[-100:])
return
ms.append(r)
print('mean',np.mean(ms), 'scale', model.scale.w.value)
print('loss',loss(xs_[-100:], ys_[-100:]))
a = encode_ema(xs_[-valid_size:])
b = decode_ema(ys_[-valid_size:])
br = b[np.random.permutation(len(b))]
print('score',np.mean(np.sum(a*b,axis=(1)) - np.sum(a*br,axis=(1))),
np.mean(np.sum(a*b,axis=(1)) > np.sum(a*br,axis=(1))))
ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
def load_challenge():
"""
Load the challenge datast for attacking
"""
global xs, ys, encoded, original, ooriginal
print("SETUP: Loading matrixes")
# The encoded images
encoded = np.load("challenge-7.npy")
# And the original images
ooriginal = original = np.load("orig-7.npy")
print("Sizes", encoded.shape, ooriginal.shape)
# Again do that ugly resize thing to make the features be on the last dimension
# Look up above to see what's going on.
patch_size = 16
ss = original.shape[2] // patch_size
original = ooriginal.reshape((original.shape[0],1,ss,patch_size,ss,patch_size))
original = original.transpose((0,2,4,1,3,5))
original = original.reshape((original.shape[0], ss**2, patch_size**2))
def match_sub(args):
"""
Find the best way to undo the permutation between two images.
"""
vec1, vec2 = args
value = np.sum((vec1[None,:,:] - vec2[:,None,:])**2,axis=2)
row, col = scipy.optimize.linear_sum_assignment(value)
return col
def recover_local_permutation():
"""
Given a set of encoded images, return a new encoding without permutations
"""
global encoded, ys
p = mp.Pool(96)
print('recover local')
local_perm = p.map(match_sub, [(encoded[0], e) for e in encoded])
local_perm = np.array(local_perm)
encoded_perm = []
for i in range(len(encoded)):
encoded_perm.append(encoded[i][np.argsort(local_perm[i])])
encoded_perm = np.array(encoded_perm)
encoded = np.array(encoded_perm)
p.close()
def recover_better_local_permutation():
"""
Given a set of encoded images, return a new encoding, but better!
"""
global encoded, ys
# Now instead of pairing all images to image 0, we compute the mean l2 vector
# and then pair all images onto the mean vector. Slightly more noise resistant.
p = mp.Pool(96)
target = encoded.mean(0)
local_perm = p.map(match_sub, [(target, e) for e in encoded])
local_perm = np.array(local_perm)
# Probably we didn't change by much, generally <0.1%
print('improved changed by', np.mean(local_perm != np.arange(local_perm.shape[1])))
encoded_perm = []
for i in range(len(encoded)):
encoded_perm.append(encoded[i][np.argsort(local_perm[i])])
encoded = np.array(encoded_perm)
p.close()
def compute_patch_similarity():
"""
Compute the feature vectors for each patch using the trained neural network.
"""
global xs, ys, xs_image, ys_image
print("Computing features")
ys = features(encoded, encoded=True)
xs = features(original, encoded=False)
model = Model()
ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
ckpt.restore(model.vars())
xs_image = model.encode(xs)
ys_image = model.decode(ys)
assert xs.shape[0] == xs_image.shape[0]
print("Done")
def match(args, ret_col=False):
"""
Compute the similarity between image features and encoded features.
"""
vec1, vec2s = args
r = []
open("/tmp/start%d.%d"%(np.random.randint(10000),time.time()),"w").write("hi")
for vec2 in vec2s:
value = np.sum(vec1[None,:,:] * vec2[:,None,:],axis=2)
row, col = scipy.optimize.linear_sum_assignment(-value)
r.append(value[row,col].mean())
return r
def recover_global_matching_first():
"""
Recover the global matching of original to encoded images by doing
an all-pairs matching problem
"""
global global_matching, ys_image, encoded
matrix = []
p = mp.Pool(96)
xs_image_ = np.array(xs_image)
ys_image_ = np.array(ys_image)
matrix = p.map(match, [(x, ys_image_) for x in xs_image_])
matrix = np.array(matrix).reshape((xs_image.shape[0],
xs_image.shape[0]))
row, col = scipy.optimize.linear_sum_assignment(-np.array(matrix))
global_matching = np.argsort(col)
print('glob',list(global_matching))
p.close()
def recover_global_permutation():
"""
Find the way that the encoded images are permuted off of the original images
"""
global global_permutation
print("Glob match", global_matching)
overall = []
for i,j in enumerate(global_matching):
overall.append(np.sum(xs_image[j][None,:,:] * ys_image[i][:,None,:],axis=2))
overall = np.mean(overall, 0)
row, col = scipy.optimize.linear_sum_assignment(-overall)
try:
print("Changed frac:", np.mean(global_permutation!=np.argsort(col)))
except:
pass
global_permutation = np.argsort(col)
def recover_global_matching_second():
"""
Match each encoded image with its original encoded image,
but better by relying on the global permutation.
"""
global global_matching_second, global_matching
ys_fix = []
for i in range(ys_image.shape[0]):
ys_fix.append(ys_image[i][global_permutation])
ys_fix = np.array(ys_fix)
print(xs_image.shape)
sims = []
for i in range(0,len(xs_image),10):
tmp = np.mean(xs_image[None,:,:,:] * ys_fix[i:i+10][:,None,:,:],axis=(2,3))
sims.extend(tmp)
sims = np.array(sims)
print(sims.shape)
row, col = scipy.optimize.linear_sum_assignment(-sims)
print('arg',sims.argmax(1))
print("Same matching frac", np.mean(col == global_matching) )
print(col)
global_matching = col
def extract_by_training(resume):
"""
Final recovery process by extracting the neural network
"""
global inverse
device = torch.device('cuda:1')
if not resume:
inverse = PrivateEncoder(Xray(True)).cuda(device)
# More adam to train.
optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001)
this_xs = ooriginal[global_matching]
this_ys = encoded[:,global_permutation,:]
for i in range(2000):
idx = np.random.random_integers(0, len(this_xs)-1, 32)
xbatch = torch.tensor(this_xs[idx]).cuda(device)
ybatch = torch.tensor(this_ys[idx]).cuda(device)
optimizer.zero_grad()
guess_output = inverse(xbatch)
# L1 loss because we don't want to be sensitive to outliers
error = torch.mean(torch.abs(guess_output-ybatch))
error.backward()
optimizer.step()
print(error)
def test_extract():
"""
Now we can recover the matching much better by computing the estimated
encodings for each original image.
"""
global err, global_matching, guessed_encoded, smatrix
device = torch.device('cuda:1')
print(ooriginal.shape, encoded.shape)
out = []
for i in range(0,len(ooriginal),32):
print(i)
out.extend(inverse(torch.tensor(ooriginal[i:i+32]).cuda(device)).cpu().detach().numpy())
guessed_encoded = np.array(out)
# Now we have to compare each encoded image with every other original image.
# Do this fast with some matrix multiplies.
out = guessed_encoded.reshape((len(encoded), -1))
real = encoded[:,global_permutation,:].reshape((len(encoded), -1))
@jax.jit
def foo(x, y):
return jn.square(x[:,None] - y[None,:]).sum(2)
smatrix = np.zeros((len(out), len(out)))
B = 500
for i in range(0,len(out),B):
print(i)
for j in range(0,len(out),B):
smatrix[i:i+B, j:j+B] = foo(out[i:i+B], real[j:j+B])
# And the final time you'l have to look at a min weight matching, I promise.
row, col = scipy.optimize.linear_sum_assignment(np.array(smatrix))
r = np.array(smatrix)
print(list(row)[::100])
print("Differences", np.mean(np.argsort(col) != global_matching))
global_matching = np.argsort(col)
def perf(steps=[]):
if len(steps) == 0:
steps.append(time.time())
else:
print("Last Time Elapsed:", time.time()-steps[-1], ' Total Time Elapsed:', time.time()-steps[0])
steps.append(time.time())
time.sleep(1)
if __name__ == "__main__":
if True:
perf()
setup('challenge')
perf()
gen_train_data()
perf()
get_train_features()
perf()
train_model()
perf()
if True:
load_challenge()
perf()
recover_local_permutation()
perf()
recover_better_local_permutation()
perf()
compute_patch_similarity()
perf()
recover_global_matching_first()
perf()
for _ in range(3):
recover_global_permutation()
perf()
recover_global_matching_second()
perf()
for i in range(3):
recover_global_permutation()
perf()
extract_by_training(i > 0)
perf()
test_extract()
perf()
print(perf())