Merge pull request #184 from carlini/instahide

Add InstaHide attack code to research folder
This commit is contained in:
Shuang Song 2022-02-14 13:23:16 -08:00 committed by GitHub
commit 66338409b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 650 additions and 0 deletions

View file

@ -0,0 +1,66 @@
Implementation of our reconstruction attack on InstaHide.
Is Private Learning Possible with Instance Encoding?
Nicholas Carlini, Samuel Deng, Sanjam Garg, Somesh Jha, Saeed Mahloujifar, Mohammad Mahmoody, Shuang Song, Abhradeep Thakurta, Florian Tramer
https://arxiv.org/abs/2011.05315
## Overview
InstaHide is a recent privacy-preserving machine learning framework.
It takes a (sensitive) dataset and generates encoded images that are privacy-preserving.
Our attack breaks InstaHide and shows it does not offer meaningful privacy.
Given the encoded dataset, we can recover a near-identical copy of the original images.
This repository implements the attack described in our paper. It consists of a number of
steps that shoul be run sequentially. It assumes access to pre-trained neural network
classifiers that should be downloaded following the steps below.
### Requirements
* Python, version ≥ 3.5
* jax
* jaxlib
* objax (https://github.com/google/objax)
* PIL
* sklearn
### Running the attack
To reproduce our results and run the attack, each of the files should be run in turn.
0. Download the necessary dependency files:
- (encryption.npy)[https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0] and (labels.npy)[https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0] from the (InstaHide Challenge)[https://github.com/Hazelsuko07/InstaHide_Challenge]
- The (saved models)[https://drive.google.com/file/d/1YfKzGRfnnzKfUKpLjIRXRto8iD4FdwGw/view?usp=sharing] used to run the attack
- Set up all the requirements as above
1. Run `step_1_create_graph.py`. Produce the similarity graph to pair together encoded images that share an original image.
2. Run `step_2_color_graph.py`. Color the graph to find 50 dense cliques.
3. Run `step_3_second_graph.py`. Create a new bipartite similarity graph.
4. Run `step_4_final_graph.py`. Solve the matching problem to assign encoded images to original images.
5. Run `step_5_reconstruct.py`. Reconstruct the original images.
6. Run `step_6_adjust_color.py`. Adjust the color curves to match.
7. Run `step_7_visualize.py`. Show the final resulting images.
## Citation
You can cite this attack at
```
@inproceedings{carlini2021private,
title={Is Private Learning Possible with Instance Encoding?},
author={Carlini, Nicholas and Deng, Samuel and Garg, Sanjam and Jha, Somesh and Mahloujifar, Saeed and Mahmoody, Mohammad and Thakurta, Abhradeep and Tram{\`e}r, Florian},
booktitle={2021 IEEE Symposium on Security and Privacy (SP)},
pages={410--427},
year={2021},
organization={IEEE}
}
```

View file

@ -0,0 +1,77 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Create the similarity graph given the encoded images by running the similarity
neural network over all pairs of images.
"""
import objax
import numpy as np
import jax.numpy as jn
import functools
import os
import random
from objax.zoo import wide_resnet
def setup():
global model
class DoesUseSame(objax.Module):
def __init__(self):
fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
self.model = fn(6,2)
model_vars = self.model.vars()
self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
def predict_op(x,y):
# The model takes the two images and checks if they correspond
# to the same original image.
xx = jn.concatenate([jn.abs(x),
jn.abs(y)],
axis=1)
return self.model(xx, training=False)
self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
self.predict_fast = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
model = DoesUseSame()
checkpoint = objax.io.Checkpoint("models/step1/", keep_ckpts=5, makedir=True)
start_epoch, last_ckpt = checkpoint.restore(model.vars())
def doall():
global graph
n = np.load("data/encryption.npy")
n = np.transpose(n, (0,3,1,2))
# Compute the similarity between each encoded image and all others
# This is n^2 work but should run fairly quickly, especially given
# more than one GPU. Otherwise about an hour or so.
graph = []
with model.vars().replicate():
for i in range(5000):
print(i)
v = model.predict_fast(np.tile(n[i:i+1], (5000,1,1,1)), n)
graph.append(np.array(v[:,0]-v[:,1]))
graph = np.array(graph)
np.save("data/graph.npy", graph)
if __name__ == "__main__":
setup()
doall()

View file

@ -0,0 +1,95 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import pickle
import random
import collections
import numpy as np
def score(subset):
sub = graph[subset]
sub = sub[:,subset]
return np.sum(sub)
def run(v, return_scores=False):
if isinstance(v, int):
v = [v]
scores = []
for _ in range(100):
keep = graph[v,:]
next_value = np.sum(keep,axis=0)
to_add = next_value.argsort()
to_add = [x for x in to_add if x not in v]
if _ < 1:
v.append(to_add[random.randint(0,10)])
else:
v.append(to_add[0])
if return_scores:
scores.append(score(v)/len(keep))
if return_scores:
return v, scores
else:
return v
def make_many_clusters():
# Compute clusters of 100 examples that probably correspond to some original image
p = mp.Pool(mp.cpu_count())
s = p.map(run, range(2000))
return s
def downselect_clusters(s):
# Right now we have a lot of clusters, but they probably overlap. Let's remove that.
# We want to find disjoint clusters, so we'll greedily add them until we have
# 100 distjoint clusters.
ss = [set(x) for x in s]
keep = []
keep_set = []
for iteration in range(2):
for this_set in s:
# MAGIC NUMBERS...!
# We want clusters of size 50 because it works
# Except on iteration 2 where we'll settle for 25 if we haven't
# found clusters with 50 neighbors that work.
cur = set(this_set[:50 - 25*iteration])
intersections = np.array([len(cur & x) for x in ss])
good = np.sum(intersections==50)>2
# Good means that this cluster isn't a fluke and some other cluster
# is like this one.
if good or iteration == 1:
print("N")
# And also make sure we haven't found this cluster (or one like it).
already_found = np.array([len(cur & x) for x in keep_set])
if np.all(already_found<len(cur)/2):
print("And is new")
keep.append(this_set)
keep_set.append(set(this_set))
if len(keep) == 100:
break
print("Found", len(keep))
if len(keep) == 100:
break
# Keep should now have 100 items.
# If it doesn't go and change the 2000 in make_many_clusters to a bigger number.
return keep
if __name__ == "__main__":
graph = np.load("data/graph.npy")
np.save("data/many_clusters",make_many_clusters())
np.save("data/100_clusters", downselect_clusters(np.load("data/many_clusters.npy")))

View file

@ -0,0 +1,114 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Create the improved graph mapping each encoded image to an original image.
"""
import objax
import numpy as np
import jax.numpy as jn
import functools
import os
import random
from objax.zoo import wide_resnet
def setup():
global model
class DoesUseSame(objax.Module):
def __init__(self):
fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
self.model = fn(3*4,2)
model_vars = self.model.vars()
self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
def predict_op(x,y):
# The model takes SEVERAL images and checks if they all correspond
# to the same original image.
# Guaranteed that the first N-1 all do, the test is if the last does.
xx = jn.concatenate([jn.abs(x),
jn.abs(y)],
axis=1)
return self.model(xx, training=False)
self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
model = DoesUseSame()
checkpoint = objax.io.Checkpoint("models/step2/", keep_ckpts=5, makedir=True)
start_epoch, last_ckpt = checkpoint.restore(model.vars())
def step2():
global v, n, u, nextgraph
# Start out by loading the encoded images
n = np.load("data/encryption.npy")
n = np.transpose(n, (0,3,1,2))
# Then load the graph with 100 cluster-centers.
keep = np.array(np.load("data/100_clusters.npy", allow_pickle=True))
graph = np.load("data/graph.npy")
# Now we're going to record the distance to each of the cluster centers
# from every encoded image, so that we can do the matching.
# To do that, though, first we need to choose the cluster centers.
# Start out by choosing the best cluster centers.
distances = []
for x in keep:
this_set = x[:50]
use_elts = graph[this_set]
distances.append(np.sum(use_elts,axis=0))
distances = np.array(distances)
ds = np.argsort(distances, axis=1)
# Now we record the "prototypes" of each cluster center.
# We just need three, more might help a little bit but not much.
# (And then do that ten times, so we can average out noise
# with respect to which cluster centers we picked.)
prototypes = []
for _ in range(10):
ps = []
# choose 3 random samples from each set
for i in range(3):
ps.append(n[ds[:,random.randint(0,20)]])
prototypes.append(np.concatenate(ps,1))
prototypes = np.concatenate(prototypes,0)
# Finally compute the distances from each node to each cluster center.
nextgraph = []
for i in range(5000):
out = model.predict(prototypes, np.tile(n[i:i+1], (1000,1,1,1)))
out = out.reshape((10, 100, 2))
v = np.sum(out,axis=0)
v = v[:,0] - v[:,1]
v = np.array(v)
nextgraph.append(v)
np.save("data/nextgraph.npy", nextgraph)
if __name__ == "__main__":
setup()
step2()

View file

@ -0,0 +1,51 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import pickle
import random
import numpy as np
labels = np.load("data/label.npy")
nextgraph = np.load("data/nextgraph.npy")
assigned = [[] for _ in range(5000)]
lambdas = [[] for _ in range(5000)]
for i in range(100):
order = (np.argsort(nextgraph[:,i]))
correct = (labels[order[:20]]>0).sum(axis=0).argmax()
# Let's create the final graph
# Instead of doing a full bipartite matching, let's just greedily
# choose the closest 80 candidates for each encoded image to pair
# together can call it a day.
# This is within a percent or two of doing that, and much easier.
# Also record the lambdas based on which image it coresponds to,
# but if they share a label then just guess it's an even 50/50 split.
for x in order[:80]:
if labels[x][correct] > 0 and len(assigned[x]) < 2:
assigned[x].append(i)
if np.sum(labels[x]>0) == 1:
# the same label was mixed in twice. punt.
lambdas[x].append(labels[x][correct]/2)
else:
lambdas[x].append(labels[x][correct])
np.save("data/predicted_pairings_80.npy", assigned)
np.save("data/predicted_lambdas_80.npy", lambdas)

View file

@ -0,0 +1,143 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The final recovery happens here. Given the graph, reconstruct images.
"""
import json
import numpy as np
import jax.numpy as jn
import jax
import collections
from PIL import Image
import jax.experimental.optimizers
import matplotlib.pyplot as plt
def toimg(x):
#x = np.transpose(x,(1,2,0))
print(x.shape)
img = (x+1)*127.5
return Image.fromarray(np.array(img,dtype=np.uint8))
def explained_variance(I, private_images, lambdas, encoded_images, public_to_private, return_mat=False):
# private images: 100x32x32x3
# encoded images: 5000x32x32x3
public_to_private = jax.nn.softmax(public_to_private,axis=-1)
# Compute the components from each of the images we know should map onto the same original image.
component_1 = jn.dot(public_to_private[0], private_images.reshape((100,-1))).reshape((5000,32,32,3))
component_2 = jn.dot(public_to_private[1], private_images.reshape((100,-1))).reshape((5000,32,32,3))
# Now combine them together to get the variance we can explain
merged = component_1 * lambdas[:,0][:,None,None,None] + component_2 * lambdas[:,1][:,None,None,None]
# And now get the variance we can't explain.
# This is the contribution of the public images.
# We want this value to be small.
def keep_smallest_abs(xx1, xx2):
t = 0
which = (jn.abs(xx1+t) < jn.abs(xx2+t)) + 0.0
return xx1 * which + xx2 * (1-which)
xx1 = jn.abs(encoded) - merged
xx2 = -(jn.abs(encoded) + merged)
xx = keep_smallest_abs(xx1, xx2)
unexplained_variance = xx
if return_mat:
return unexplained_variance, xx1, xx2
extra = (1-jn.abs(private_images)).mean()*.05
return extra + (unexplained_variance**2).mean()
def setup():
global private, imagenet40, encoded, lambdas, using, real_using, pub_using
# Load all the things we've made.
encoded = np.load("data/encryption.npy")
labels = np.load("data/label.npy")
using = np.load("data/predicted_pairings_80.npy", allow_pickle=True)
lambdas = list(np.load("data/predicted_lambdas_80.npy", allow_pickle=True))
for x in lambdas:
while len(x) < 2:
x.append(0)
lambdas = np.array(lambdas)
# Construct the mapping
public_to_private_new = np.zeros((2, 5000, 100))
cs = [0]*100
for i,row in enumerate(using):
for j,b in enumerate(row[:2]):
public_to_private_new[j, i, b] = 1e9
cs[b] += 1
using = public_to_private_new
def loss(private, lams, I):
return explained_variance(I, private, lams, jn.array(encoded), jn.array(using))
def make_loss():
global vg
vg = jax.jit(jax.value_and_grad(loss, argnums=(0,1)))
def run():
priv = np.zeros((100,32,32,3))
uusing = np.array(using)
lams = np.array(lambdas)
# Use Adam, because thinking hard is overrated we have magic pixie dust.
init_1, opt_update_1, get_params_1 = jax.experimental.optimizers.adam(.01)
@jax.jit
def update_1(i, opt_state, gs):
return opt_update_1(i, gs, opt_state)
opt_state_1 = init_1(priv)
# 1000 iterations of gradient descent is probably enough
for i in range(1000):
value, grad = vg(priv, lams, i)
if i%100 == 0:
print(value)
var,_,_ = explained_variance(0, priv, jn.array(lambdas), jn.array(encoded), jn.array(using),
return_mat=True)
print('unexplained min/max', var.min(), var.max())
opt_state_1 = update_1(i, opt_state_1, grad[0])
priv = opt_state_1.packed_state[0][0]
priv -= np.min(priv, axis=(1,2,3), keepdims=True)
priv /= np.max(priv, axis=(1,2,3), keepdims=True)
priv *= 2
priv -= 1
# Finally save the stored values
np.save("data/private_raw.npy", priv)
if __name__ == "__main__":
setup()
make_loss()
run()

View file

@ -0,0 +1,66 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fix the color curves. Use a pre-trained "neural network" with <100 weights.
Visually this helps a lot, even if it's not doing much of anything in pactice.
"""
import random
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import numpy as np
import jax.numpy as jn
import objax
# Our extremely complicated neural network to re-color the images.
# Takes one pixel at a time and fixes the color of that pixel.
model = objax.nn.Sequential([objax.nn.Linear(3, 10),
objax.functional.relu,
objax.nn.Linear(10, 3)
])
# These are the weights.
weights = [[-0.09795442, -0.26434848, -0.24964345, -0.11450608, 0.6797288, -0.48435465,
0.45307165, -0.31196147, -0.33266315, 0.20486055],
[[-0.9056427, 0.02872663, -1.5114126, -0.41024876, -0.98195165, 0.1143966,
0.6763464, -0.58654785, -1.797063, -0.2176538, ],
[ 1.1941166, 0.15515928, 1.1691351, -0.7256186, 0.8046044, 1.3127686,
-0.77297133, -1.1761239, 0.85841715, 0.95545965],
[ 0.20092924, 0.57503146, 0.22809981, 1.5288007, -0.94781816, -0.68305916,
-0.5245211, 1.4042739, -0.00527458, -1.1462274, ]],
[0.15683544, 0.22086962, 0.33100453],
[[ 7.7239674e-01, 4.0261227e-01, -9.6466336e-03],
[-2.2159107e-01, 1.5123411e-01, 3.4485441e-01],
[-1.7618114e+00, -7.1886492e-01, -4.6467595e-02],
[ 6.9419539e-01, 6.2531930e-01, 7.2271496e-01],
[-1.1913675e+00, -6.7755884e-01, -3.5114303e-01],
[ 4.8022485e-01, 1.7145030e-01, 7.4849324e-04],
[ 3.8332436e-02, -7.0614147e-01, -5.5127507e-01],
[-1.0929481e+00, -1.0268525e+00, -7.0265180e-01],
[ 1.4880739e+00, 7.1450096e-01, 2.9102692e-01],
[ 7.2846663e-01, 7.1322352e-01, -1.7453632e-01]]]
for i,(k,v) in enumerate(model.vars().items()):
v.assign(jn.array(weights[i]))
# Do all of the re-coloring
predict = objax.Jit(lambda x: model(x, training=False),
model.vars())
out = model(np.load("data/private_raw.npy"))
np.save("data/private.npy", out)

View file

@ -0,0 +1,38 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Given the private images, draw them in a 100x100 grid for visualization.
"""
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
p = np.load("data/private.npy")
def toimg(x):
print(x.shape)
img = (x+1)*127.5
img = np.clip(img, 0, 255)
img = np.reshape(img, (10, 10, 32, 32, 3))
img = np.concatenate(img, axis=2)
img = np.concatenate(img, axis=0)
img = Image.fromarray(np.array(img,dtype=np.uint8))
return img
toimg(p).save("data/reconstructed.png")