tensorflow_privacy/research/instahide_attack_2020/step_5_reconstruct.py
2020-12-05 01:20:49 +00:00

143 lines
4.4 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The final recovery happens here. Given the graph, reconstruct images.
"""
import json
import numpy as np
import jax.numpy as jn
import jax
import collections
from PIL import Image
import jax.experimental.optimizers
import matplotlib.pyplot as plt
def toimg(x):
    """Convert an image array with values in [-1, 1] to a PIL image.

    Pixel values are mapped linearly from [-1, 1] to [0, 255] and cast to
    uint8 (truncating, as before). Values outside [-1, 1] are clipped to the
    valid byte range rather than wrapping around.

    Args:
      x: Image array. NOTE(review): assumed to already be in a layout that
        ``Image.fromarray`` accepts (e.g. HxWxC); no transpose is applied.

    Returns:
      A PIL image built from the uint8 pixel array.
    """
    print(x.shape)
    img = (np.asarray(x) + 1) * 127.5
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around or hit an undefined float->uint8 conversion.
    return Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
def explained_variance(I, private_images, lambdas, encoded_images, public_to_private, return_mat=False):
    """Score how well candidate private images explain the encoded images.

    Each encoded image is modelled as a lambda-weighted sum of two private
    images. Each of the two components is chosen by a softmax over
    ``public_to_private``. Whatever that sum cannot account for is treated as
    the contribution of the public images and is penalised.

    Args:
      I: Iteration index. Unused; kept so existing callers keep working.
      private_images: Candidate private images, e.g. 100x32x32x3. The first
        axis indexes the private images.
      lambdas: Mixing weights. Column 0 weights the first component and
        column 1 the second, one row per encoded image.
      encoded_images: The encodings to explain, e.g. 5000x32x32x3.
      public_to_private: Logits of shape (2, num_encoded, num_private). A
        softmax over the last axis gives, for each encoded image and each of
        the two components, a (soft) choice of private image.
      return_mat: If True, return per-pixel residual arrays instead of a
        scalar loss.

    Returns:
      If ``return_mat`` is False, a scalar: the mean squared unexplained
      residual plus ``0.05 * mean(1 - |private_images|)``.
      If True, the tuple ``(residual, xx1, xx2)``, each shaped like
      ``encoded_images``.
    """
    # private images: 100x32x32x3
    # encoded images: 5000x32x32x3
    public_to_private = jax.nn.softmax(public_to_private, axis=-1)
    # Flatten each private image so one matmul selects/mixes whole images.
    flat_private = private_images.reshape((private_images.shape[0], -1))
    # Compute the components from each of the images we know should map onto
    # the same original image.
    component_1 = jn.dot(public_to_private[0], flat_private).reshape(encoded_images.shape)
    component_2 = jn.dot(public_to_private[1], flat_private).reshape(encoded_images.shape)
    # Shape each lambda column so it broadcasts over all non-batch axes.
    weight_shape = (-1,) + (1,) * (encoded_images.ndim - 1)
    # Now combine them together to get the variance we can explain.
    merged = (component_1 * lambdas[:, 0].reshape(weight_shape)
              + component_2 * lambdas[:, 1].reshape(weight_shape))
    # And now get the variance we can't explain: the contribution of the
    # public images, which we want to be small. Only |encoded| is compared,
    # against both sign hypotheses (presumably because the encoding's
    # per-pixel sign is unknown), and the smaller-magnitude residual is kept.
    abs_encoded = jn.abs(encoded_images)
    xx1 = abs_encoded - merged
    xx2 = -(abs_encoded + merged)
    unexplained_variance = jn.where(jn.abs(xx1) < jn.abs(xx2), xx1, xx2)
    if return_mat:
        return unexplained_variance, xx1, xx2
    # Small penalty that shrinks as |pixel| grows toward 1, pushing private
    # pixels toward the ends of the [-1, 1] range.
    extra = (1 - jn.abs(private_images)).mean() * .05
    return extra + (unexplained_variance ** 2).mean()
def setup():
    """Load the attack artifacts from ``data/`` into module globals.

    Sets three globals:

    - ``encoded``: the array loaded from ``data/encryption.npy``.
    - ``lambdas``: the predicted mixing weights, each row zero-padded to at
      least two entries.
    - ``using``: a logit tensor of shape (2, len(encoded), 100). Entry
      ``[j, i, b]`` is 1e9 when private image ``b`` is the j-th predicted
      partner of encoded image ``i``, and 0 otherwise.
    """
    global encoded, lambdas, using
    # Load all the things we've made.
    encoded = np.load("data/encryption.npy")
    using = np.load("data/predicted_pairings_80.npy", allow_pickle=True)
    raw_lambdas = np.load("data/predicted_lambdas_80.npy", allow_pickle=True)
    # Pad every weight list to length 2 so a missing second component gets
    # weight 0; builds new lists instead of mutating the loaded objects.
    lambdas = np.array([list(x) + [0] * (2 - len(x)) for x in raw_lambdas])

    # Construct the mapping. The 1e9 logit makes the softmax applied later
    # effectively one-hot on the predicted private image.
    num_private = 100
    public_to_private_new = np.zeros((2, len(encoded), num_private))
    for i, row in enumerate(using):
        for j, b in enumerate(row[:2]):
            public_to_private_new[j, i, b] = 1e9
    using = public_to_private_new
def loss(private, lams, I):
    """Objective optimised by ``run``.

    Evaluates ``explained_variance`` for the given private images and
    lambdas against the globally loaded ``encoded`` and ``using`` arrays.
    """
    encodings = jn.array(encoded)
    pairing_logits = jn.array(using)
    return explained_variance(I, private, lams, encodings, pairing_logits)
def make_loss():
    """Build the jitted value-and-gradient function for ``loss``.

    Gradients are taken with respect to the first two arguments (the private
    images and the lambdas). The compiled function is stored in the global
    ``vg``.
    """
    global vg
    value_and_grad_fn = jax.value_and_grad(loss, argnums=(0, 1))
    vg = jax.jit(value_and_grad_fn)
def run():
    """Optimise the private images with Adam and save them.

    Runs 1000 Adam steps (learning rate 0.01) on the private images only,
    starting from all-zero 100x32x32x3 images. Every 100 steps it prints the
    loss and the min/max unexplained residual.

    Afterwards, each image is linearly rescaled to [-1, 1]; a constant image
    maps to -1 instead of NaN. The result is written to
    ``data/private_raw.npy``.

    Requires ``setup()`` and ``make_loss()`` to have been called first.
    """
    priv = np.zeros((100, 32, 32, 3))
    lams = np.array(lambdas)
    # Use Adam, because thinking hard is overrated we have magic pixie dust.
    # NOTE(review): newer JAX releases expose this module as
    # jax.example_libraries.optimizers — confirm against the installed version.
    init_1, opt_update_1, get_params_1 = jax.experimental.optimizers.adam(.01)

    @jax.jit
    def update_1(i, opt_state, gs):
        return opt_update_1(i, gs, opt_state)

    opt_state_1 = init_1(priv)
    # 1000 iterations of gradient descent is probably enough
    for i in range(1000):
        value, grad = vg(priv, lams, i)
        if i % 100 == 0:
            print(value)
            var, _, _ = explained_variance(0, priv, jn.array(lambdas), jn.array(encoded), jn.array(using),
                                           return_mat=True)
            print('unexplained min/max', var.min(), var.max())
        # Only the private images are optimised; grad[1] (lambdas) is unused.
        opt_state_1 = update_1(i, opt_state_1, grad[0])
        # Public accessor instead of reaching into the optimizer's packed_state.
        priv = get_params_1(opt_state_1)

    # Rescale each image independently to [-1, 1].
    priv = np.asarray(priv)
    low = np.min(priv, axis=(1, 2, 3), keepdims=True)
    span = np.max(priv, axis=(1, 2, 3), keepdims=True) - low
    # Avoid 0/0 for a constant image.
    priv = (priv - low) / np.where(span > 0, span, 1) * 2 - 1
    # Finally save the stored values
    np.save("data/private_raw.npy", priv)
if __name__ == "__main__":
    # Pipeline: load the data, build the jitted loss/gradient, then optimise
    # and save the reconstructed private images.
    for step in (setup, make_loss, run):
        step()