PiperOrigin-RevId: 428549678
This commit is contained in:
Michael Reneer 2022-02-14 18:34:55 +00:00
parent c8bba41059
commit 8012d5b9c9
25 changed files with 0 additions and 3053 deletions


@ -1,11 +0,0 @@
# Auditing Private Machine Learning
Code for "Auditing Differentially Private Machine Learning: How Private is Private SGD?": https://arxiv.org/abs/2006.07709. This implementation is simple but not easily parallelizable. For a parallelizable version which is harder to run, see https://github.com/jagielski/auditing-dpsgd.
## Usage
This attack relies on the AuditAttack class found in audit.py. The class allows one to generate poisoning, run trials to compute membership scores for the poisoning, and then use the resulting membership scores to compute a lower bound on epsilon.
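A minimal sketch of this workflow is shown below, assuming `audit.py` and `attacks.py` are importable; the `train_and_score` stub is only a placeholder for a real training function such as the one in `fmnist_audit.py`.

```python
import numpy as np

import audit


def train_and_score(dataset):
  """Placeholder: train a model on (x, y) and score the poisoning sample."""
  x, y, pois_x, pois_y, seed = dataset
  del x, y, pois_x, pois_y  # a real implementation trains and scores here
  return float(np.random.RandomState(seed).rand())


# Toy data: 100 examples, 4 features, one-hot labels for 2 classes.
train_x = np.random.rand(100, 4)
train_y = np.eye(2)[np.random.randint(0, 2, size=100)]

auditor = audit.AuditAttack(train_x, train_y, train_and_score)
threshold, eps_lower_bound, acc = auditor.run(
    pois_ct=1, attack_type="clip_aware", num_trials=20, alpha=0.05)
print(threshold, eps_lower_bound, acc)
```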
## Examples
Two examples are provided, mean_audit.py and fmnist_audit.py. fmnist_audit.py attacks the FashionMNIST dataset; it lets the user choose between standard backdoor attacks and clipping-aware attacks, select among multiple poisoning attack sizes and model types, and optionally start training from saved model weights. mean_audit.py audits a model which computes the mean of a dataset; it provides an example of user-provided poisoning samples, rather than those autogenerated by our attacks.py library.
## Requirements
Requires scikit-learn==0.24.1, statsmodels==0.12.2, and tensorflow==1.14.0.


@ -1,115 +0,0 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Poisoning attack library for auditing."""
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
def make_clip_aware(train_x, train_y, l2_norm=10):
"""
train_x: clean training features - must be shape (n_samples, n_features)
train_y: clean one-hot training labels - must be shape (n_samples, n_classes)
Returns x, y1, y2
x: poisoning sample
y1: first corresponding y value
y2: second corresponding y value
"""
x_shape = list(train_x.shape[1:])
to_image = lambda x: x.reshape([-1] + x_shape) # reshapes to standard image shape
flatten = lambda x: x.reshape((x.shape[0], -1)) # flattens all pixels - allows PCA
# make sure to_image and flatten are inverse functions
assert np.allclose(to_image(flatten(train_x)), train_x)
flat_x = flatten(train_x)
pca = PCA(flat_x.shape[1])
pca.fit(flat_x)
new_x = l2_norm*pca.components_[-1]
lr = LogisticRegression(max_iter=1000)
lr.fit(flat_x, np.argmax(train_y, axis=1))
num_classes = train_y.shape[1]
lr_probs = lr.predict_proba(new_x[None, :])
min_y = np.argmin(lr_probs)
second_y = np.argmin(lr_probs + np.eye(num_classes)[min_y])
oh_min_y = np.eye(num_classes)[min_y]
oh_second_y = np.eye(num_classes)[second_y]
return to_image(new_x), oh_min_y, oh_second_y
def make_backdoor(train_x, train_y):
"""
Makes a backdoored dataset, following Gu et al. https://arxiv.org/abs/1708.06733
train_x: clean training features - must be shape (n_samples, n_features)
train_y: clean one-hot training labels - must be shape (n_samples, n_classes)
Returns x, y1, y2
x: poisoning sample
y1: first corresponding y value
y2: second corresponding y value
"""
sample_ind = np.random.choice(train_x.shape[0], 1)
pois_x = np.copy(train_x[sample_ind, :])
pois_x[0] = 1 # set corner feature to 1
second_y = train_y[sample_ind]
num_classes = train_y.shape[1]
min_y = np.eye(num_classes)[second_y.argmax(1) + 1]
return pois_x, min_y, second_y
def make_many_poisoned_datasets(train_x, train_y, pois_sizes, attack="clip_aware", l2_norm=10):
"""
Makes a dict containing many poisoned datasets. Generating the poisoning
sample is fairly slow, so it is done once here and reused for every size.
train_x: clean training features - shape (n_samples, n_features)
train_y: clean one-hot training labels - shape (n_samples, n_classes)
pois_sizes: list of poisoning sizes
attack: "clip_aware" or "backdoor"
l2_norm: l2 norm of the poisoned data (clip_aware only)
Returns dict: all_poisons
all_poisons[poison_size] is a pair of poisoned datasets
"""
if attack == "clip_aware":
pois_sample_x, y, second_y = make_clip_aware(train_x, train_y, l2_norm)
elif attack == "backdoor":
pois_sample_x, y, second_y = make_backdoor(train_x, train_y)
else:
raise NotImplementedError
all_poisons = {"pois": (pois_sample_x, y)}
for pois_size in pois_sizes:  # poisoning generation is slow - don't want it in the loop
new_pois_x1, new_pois_y1 = train_x.copy(), train_y.copy()
new_pois_x2, new_pois_y2 = train_x.copy(), train_y.copy()
new_pois_x1[-pois_size:] = pois_sample_x[None, :]
new_pois_y1[-pois_size:] = y
new_pois_x2[-pois_size:] = pois_sample_x[None, :]
new_pois_y2[-pois_size:] = second_y
dataset1, dataset2 = (new_pois_x1, new_pois_y1), (new_pois_x2, new_pois_y2)
all_poisons[pois_size] = dataset1, dataset2
return all_poisons


@ -1,119 +0,0 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Class for running auditing procedure."""
import numpy as np
from statsmodels.stats import proportion
import attacks
def compute_results(poison_scores, unpois_scores, pois_ct,
alpha=0.05, threshold=None):
"""
Searches over thresholds for the best epsilon lower bound and accuracy.
poison_scores: list of scores from poisoned models
unpois_scores: list of scores from unpoisoned models
pois_ct: number of poison points
alpha: confidence parameter
threshold: if None, search over all thresholds, else use given threshold
"""
if threshold is None: # search for best threshold
all_thresholds = np.unique(poison_scores + unpois_scores)
else:
all_thresholds = [threshold]
poison_arr = np.array(poison_scores)
unpois_arr = np.array(unpois_scores)
best_threshold, best_epsilon, best_acc = None, 0, 0
for thresh in all_thresholds:
epsilon, acc = compute_epsilon_and_acc(poison_arr, unpois_arr, thresh,
alpha, pois_ct)
if epsilon > best_epsilon:
best_epsilon, best_threshold = epsilon, thresh
best_acc = max(best_acc, acc)
return best_threshold, best_epsilon, best_acc
def compute_epsilon_and_acc(poison_arr, unpois_arr, threshold, alpha, pois_ct):
"""For a given threshold, compute epsilon and accuracy."""
poison_ct = (poison_arr > threshold).sum()
unpois_ct = (unpois_arr > threshold).sum()
# clopper_pearson uses alpha/2 budget on upper and lower
# so total budget will be 2*alpha/2 = alpha
p1, _ = proportion.proportion_confint(poison_ct, poison_arr.size,
alpha, method='beta')
_, p0 = proportion.proportion_confint(unpois_ct, unpois_arr.size,
alpha, method='beta')
if (p1 <= 1e-5) or (p0 >= 1 - 1e-5): # divide by zero issues
return 0, 0
if (p0 + p1) > 1: # see Appendix A
p0, p1 = (1-p1), (1-p0)
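# By group privacy over pois_ct poisoned points, p1 <= exp(pois_ct*epsilon)*p0,
# so log(p1/p0)/pois_ct is a valid lower bound on epsilon.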
epsilon = np.log(p1/p0)/pois_ct
acc = (p1 + (1-p0))/2 # this is not necessarily the best accuracy
return epsilon, acc
class AuditAttack(object):
"""Audit attack class. Generates poisoning, then runs auditing algorithm."""
def __init__(self, train_x, train_y, train_function):
"""
train_x: training features
train_y: training labels
train_function: function which takes a (x, y, pois_x, pois_y, seed) tuple,
trains a model, and returns a membership score for the poisoning sample
"""
self.train_x, self.train_y = train_x, train_y
self.train_function = train_function
self.poisoning = None
def make_poisoning(self, pois_ct, attack_type, l2_norm=10):
"""Get poisoning data."""
return attacks.make_many_poisoned_datasets(self.train_x, self.train_y, [pois_ct],
attack=attack_type, l2_norm=l2_norm)
def run_experiments(self, num_trials):
"""Runs all training experiments."""
(pois_x1, pois_y1), (pois_x2, pois_y2) = self.poisoning['data']
sample_x, sample_y = self.poisoning['pois']
poison_scores = []
unpois_scores = []
for i in range(num_trials):
poison_tuple = (pois_x1, pois_y1, sample_x, sample_y, i)
unpois_tuple = (pois_x2, pois_y2, sample_x, sample_y, num_trials + i)
poison_scores.append(self.train_function(poison_tuple))
unpois_scores.append(self.train_function(unpois_tuple))
return poison_scores, unpois_scores
def run(self, pois_ct, attack_type, num_trials, alpha=0.05,
threshold=None, l2_norm=10):
"""Complete auditing algorithm. Generates poisoning if necessary."""
if self.poisoning is None:
self.poisoning = self.make_poisoning(pois_ct, attack_type, l2_norm)
self.poisoning['data'] = self.poisoning[pois_ct]
poison_scores, unpois_scores = self.run_experiments(num_trials)
results = compute_results(poison_scores, unpois_scores, pois_ct,
alpha=alpha, threshold=threshold)
return results


@ -1,91 +0,0 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for audit.py."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import audit
def dummy_train_and_score_function(dataset):
del dataset
return 0
def get_auditor():
poisoning = {}
datasets = (np.zeros((5, 2)), np.zeros(5)), (np.zeros((5, 2)), np.zeros(5))
poisoning["data"] = datasets
poisoning["pois"] = (datasets[0][0][0], datasets[0][1][0])
auditor = audit.AuditAttack(datasets[0][0], datasets[0][1],
dummy_train_and_score_function)
auditor.poisoning = poisoning
return auditor
class AuditParameterizedTest(parameterized.TestCase):
"""Class to test parameterized audit.py functions."""
@parameterized.named_parameters(
('Test0', np.ones(500), np.zeros(500), 0.5, 0.01, 1,
(4.541915810224092, 0.9894593118113243)),
('Test1', np.ones(500), np.zeros(500), 0.5, 0.01, 2,
(2.27095790511, 0.9894593118113243)),
('Test2', np.ones(500), np.ones(500), 0.5, 0.01, 1,
(0, 0))
)
def test_compute_epsilon_and_acc(self, poison_scores, unpois_scores,
threshold, pois_ct, alpha, expected_res):
expected_eps, expected_acc = expected_res
computed_res = audit.compute_epsilon_and_acc(poison_scores, unpois_scores,
threshold, pois_ct, alpha)
computed_eps, computed_acc = computed_res
self.assertAlmostEqual(computed_eps, expected_eps)
self.assertAlmostEqual(computed_acc, expected_acc)
@parameterized.named_parameters(
('Test0', [1]*500, [0]*250 + [.5]*250, 1, 0.01, .5,
(.5, 4.541915810224092, 0.9894593118113243)),
('Test1', [1]*500, [0]*250 + [.5]*250, 1, 0.01, None,
(.5, 4.541915810224092, 0.9894593118113243)),
('Test2', [1]*500, [0]*500, 2, 0.01, .5,
(.5, 2.27095790511, 0.9894593118113243)),
)
def test_compute_results(self, poison_scores, unpois_scores, pois_ct,
alpha, threshold, expected_res):
expected_thresh, expected_eps, expected_acc = expected_res
computed_res = audit.compute_results(poison_scores, unpois_scores,
pois_ct, alpha, threshold)
computed_thresh, computed_eps, computed_acc = computed_res
self.assertAlmostEqual(computed_thresh, expected_thresh)
self.assertAlmostEqual(computed_eps, expected_eps)
self.assertAlmostEqual(computed_acc, expected_acc)
class AuditAttackTest(absltest.TestCase):
"""Nonparameterized audit.py test class."""
def test_run_experiments(self):
auditor = get_auditor()
pois, unpois = auditor.run_experiments(100)
expected = [0]*100
self.assertListEqual(pois, expected)
self.assertListEqual(unpois, expected)
if __name__ == '__main__':
absltest.main()


@ -1,176 +0,0 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Run auditing on the FashionMNIST dataset."""
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
from absl import app
from absl import flags
import audit
#### FLAGS
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('epochs', 24, 'Number of epochs')
flags.DEFINE_integer(
'microbatches', 250, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_string('model', 'lr', 'model to use, pick between lr and nn')
flags.DEFINE_string('attack_type', "clip_aware", 'clip_aware or backdoor')
flags.DEFINE_integer('pois_ct', 1, 'Number of poisoning points')
flags.DEFINE_integer('num_trials', 100, 'Number of trials for auditing')
flags.DEFINE_float('attack_l2_norm', 10, 'Size of poisoning data')
flags.DEFINE_float('alpha', 0.05, '1-confidence')
flags.DEFINE_boolean('load_weights', False,
'if True, use weights saved in init_weights.h5')
FLAGS = flags.FLAGS
def compute_epsilon(train_size):
"""Computes epsilon value for given hyperparameters."""
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / train_size
steps = FLAGS.epochs * train_size / FLAGS.batch_size
rdp = compute_rdp(q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to approximate 1 / (number of training points).
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
def build_model(x, y):
"""Build a keras model."""
input_shape = x.shape[1:]
num_classes = y.shape[1]
l2 = 0
if FLAGS.model == 'lr':
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=input_shape),
tf.keras.layers.Dense(num_classes, kernel_initializer='glorot_normal',
kernel_regularizer=tf.keras.regularizers.l2(l2))
])
elif FLAGS.model == 'nn':
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=input_shape),
tf.keras.layers.Dense(32, activation='relu',
kernel_initializer='glorot_normal',
kernel_regularizer=tf.keras.regularizers.l2(l2)),
tf.keras.layers.Dense(num_classes, kernel_initializer='glorot_normal',
kernel_regularizer=tf.keras.regularizers.l2(l2))
])
else:
raise NotImplementedError
return model
def train_model(model, train_x, train_y, save_weights=False):
"""Train the model on given data."""
optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
learning_rate=FLAGS.learning_rate)
loss = tf.keras.losses.CategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE)
# Compile model with Keras
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
if save_weights:
wts = model.get_weights()
np.save('save_model', wts)
model.set_weights(wts)
return model
if FLAGS.load_weights: # load preset weights
wts = np.load('save_model.npy', allow_pickle=True).tolist()
model.set_weights(wts)
# Train model with Keras
model.fit(train_x, train_y,
epochs=FLAGS.epochs,
validation_data=(train_x, train_y),
batch_size=FLAGS.batch_size,
verbose=0)
return model
def membership_test(model, pois_x, pois_y):
"""Membership inference - detect poisoning."""
probs = model.predict(np.concatenate([pois_x, np.zeros_like(pois_x)]))
return np.multiply(probs[0, :] - probs[1, :], pois_y).sum()
def train_and_score(dataset):
"""Complete training run with membership inference score."""
x, y, pois_x, pois_y, i = dataset
np.random.seed(i)
tf.set_random_seed(i)
tf.reset_default_graph()
model = build_model(x, y)
model = train_model(model, x, y)
return membership_test(model, pois_x, pois_y)
def main(unused_argv):
del unused_argv
# Load training and test data.
np.random.seed(0)
(train_x, train_y), _ = tf.keras.datasets.fashion_mnist.load_data()
train_inds = np.where(train_y < 2)[0]
train_x = -.5 + train_x[train_inds] / 255.
train_y = np.eye(2)[train_y[train_inds]]
# subsample dataset
ss_inds = np.random.choice(train_x.shape[0], train_x.shape[0]//2, replace=False)
train_x = train_x[ss_inds]
train_y = train_y[ss_inds]
init_model = build_model(train_x, train_y)
_ = train_model(init_model, train_x, train_y, save_weights=True)
auditor = audit.AuditAttack(train_x, train_y, train_and_score)
thresh, _, _ = auditor.run(FLAGS.pois_ct, FLAGS.attack_type, FLAGS.num_trials,
alpha=FLAGS.alpha, threshold=None,
l2_norm=FLAGS.attack_l2_norm)
_, eps, acc = auditor.run(FLAGS.pois_ct, FLAGS.attack_type, FLAGS.num_trials,
alpha=FLAGS.alpha, threshold=thresh,
l2_norm=FLAGS.attack_l2_norm)
epsilon_upper_bound = compute_epsilon(train_x.shape[0])
print("Analysis epsilon is {}.".format(epsilon_upper_bound))
print("At threshold={}, epsilon={}.".format(thresh, eps))
print("The best accuracy at distinguishing poisoning is {}.".format(acc))
if __name__ == '__main__':
app.run(main)


@ -1,152 +0,0 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Auditing a model which computes the mean of a synthetic dataset.
This gives an example of instrumenting the auditor to audit a user-provided sample."""
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
from absl import app
from absl import flags
import audit
#### FLAGS
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('d', 250, 'Data dimension')
flags.DEFINE_integer('epochs', 1, 'Number of epochs')
flags.DEFINE_integer(
'microbatches', 250, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_string('attack_type', "clip_aware", 'clip_aware or backdoor')
flags.DEFINE_integer('num_trials', 100, 'Number of trials for auditing')
flags.DEFINE_float('attack_l2_norm', 10, 'Size of poisoning data')
flags.DEFINE_float('alpha', 0.05, '1-confidence')
FLAGS = flags.FLAGS
def compute_epsilon(train_size):
"""Computes epsilon value for given hyperparameters."""
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / train_size
steps = FLAGS.epochs * train_size / FLAGS.batch_size
rdp = compute_rdp(q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to approximate 1 / (number of training points).
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
def build_model(x, y):
del x, y
model = tf.keras.Sequential([tf.keras.layers.Dense(
1, input_shape=(FLAGS.d,),
use_bias=False, kernel_initializer=tf.keras.initializers.Zeros())])
return model
def train_model(model, train_x, train_y):
"""Train the model on given data."""
optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
learning_rate=FLAGS.learning_rate)
# gradient of (.5 - x.w)^2 with respect to w is -2*(.5 - x.w)*x
loss = tf.keras.losses.MeanSquaredError(reduction=tf.losses.Reduction.NONE)
# Compile model with Keras
model.compile(optimizer=optimizer, loss=loss, metrics=['mse'])
# Train model with Keras
model.fit(train_x, train_y,
epochs=FLAGS.epochs,
validation_data=(train_x, train_y),
batch_size=FLAGS.batch_size,
verbose=0)
return model
def membership_test(model, pois_x, pois_y):
"""Membership inference - detect poisoning."""
del pois_y
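# The model's prediction on the target point serves directly as the membership score.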
return model.predict(pois_x)
def gen_data(n, d):
"""Make binomial dataset."""
x = np.random.normal(size=(n, d))
y = np.ones(shape=(n,))/2.
return x, y
def train_and_score(dataset):
"""Complete training run with membership inference score."""
x, y, pois_x, pois_y, i = dataset
np.random.seed(i)
tf.set_random_seed(i)
model = build_model(x, y)
model = train_model(model, x, y)
return membership_test(model, pois_x, pois_y)
def main(unused_argv):
del unused_argv
# Load training and test data.
np.random.seed(0)
x, y = gen_data(1 + FLAGS.batch_size, FLAGS.d)
auditor = audit.AuditAttack(x, y, train_and_score)
# we will instrument the auditor to simply use the last sample as the poisoning point
pois_x1, pois_x2 = x[:-1].copy(), x[:-1].copy()
pois_x1[-1] = x[-1]
pois_y = y[:-1]
target_x = x[-1][None, :]
assert np.unique(np.nonzero(pois_x1 - pois_x2)[0]).size == 1
pois_data = (pois_x1, pois_y), (pois_x2, pois_y), (target_x, y[-1])
poisoning = {}
poisoning["data"] = (pois_data[0], pois_data[1])
poisoning["pois"] = pois_data[2]
auditor.poisoning = poisoning
thresh, _, _ = auditor.run(1, None, FLAGS.num_trials, alpha=FLAGS.alpha)
_, eps, acc = auditor.run(1, None, FLAGS.num_trials, alpha=FLAGS.alpha,
threshold=thresh)
epsilon_upper_bound = compute_epsilon(FLAGS.batch_size)
print("Analysis epsilon is {}.".format(epsilon_upper_bound))
print("At threshold={}, epsilon={}.".format(thresh, eps))
print("The best accuracy at distinguishing poisoning is {}.".format(acc))
if __name__ == '__main__':
app.run(main)


@ -1,66 +0,0 @@
Implementation of our reconstruction attack on InstaHide.
Is Private Learning Possible with Instance Encoding?
Nicholas Carlini, Samuel Deng, Sanjam Garg, Somesh Jha, Saeed Mahloujifar, Mohammad Mahmoody, Shuang Song, Abhradeep Thakurta, Florian Tramer
https://arxiv.org/abs/2011.05315
## Overview
InstaHide is a recent privacy-preserving machine learning framework.
It takes a (sensitive) dataset and generates encoded images that are privacy-preserving.
Our attack breaks InstaHide and shows it does not offer meaningful privacy.
Given the encoded dataset, we can recover a near-identical copy of the original images.
This repository implements the attack described in our paper. It consists of a number of
steps that should be run sequentially. It assumes access to pre-trained neural network
classifiers that should be downloaded following the steps below.
### Requirements
* Python, version &ge; 3.5
* jax
* jaxlib
* objax (https://github.com/google/objax)
* PIL
* sklearn
### Running the attack
To reproduce our results and run the attack, each of the files should be run in turn.
0. Download the necessary dependency files:
- [encryption.npy](https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0) and [labels.npy](https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0) from the [InstaHide Challenge](https://github.com/Hazelsuko07/InstaHide_Challenge)
- The [saved models](https://drive.google.com/file/d/1YfKzGRfnnzKfUKpLjIRXRto8iD4FdwGw/view?usp=sharing) used to run the attack
- Set up all the requirements as above
1. Run `step_1_create_graph.py`. Produce the similarity graph to pair together encoded images that share an original image.
2. Run `step_2_color_graph.py`. Color the graph to find 50 dense cliques.
3. Run `step_3_second_graph.py`. Create a new bipartite similarity graph.
4. Run `step_4_final_graph.py`. Solve the matching problem to assign encoded images to original images.
5. Run `step_5_reconstruct.py`. Reconstruct the original images.
6. Run `step_6_adjust_color.py`. Adjust the color curves to match.
7. Run `step_7_visualize.py`. Show the final resulting images.
## Citation
You can cite this attack with
```
@inproceedings{carlini2021private,
title={Is Private Learning Possible with Instance Encoding?},
author={Carlini, Nicholas and Deng, Samuel and Garg, Sanjam and Jha, Somesh and Mahloujifar, Saeed and Mahmoody, Mohammad and Thakurta, Abhradeep and Tram{\`e}r, Florian},
booktitle={2021 IEEE Symposium on Security and Privacy (SP)},
pages={410--427},
year={2021},
organization={IEEE}
}
```


@ -1,77 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Create the similarity graph given the encoded images by running the similarity
neural network over all pairs of images.
"""
import objax
import numpy as np
import jax.numpy as jn
import functools
import os
import random
from objax.zoo import wide_resnet
def setup():
global model
class DoesUseSame(objax.Module):
def __init__(self):
fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
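# 6 input channels: the two 3-channel images are stacked along the channel
# axis below; 2 outputs: same original image or not.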
self.model = fn(6,2)
model_vars = self.model.vars()
self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
def predict_op(x,y):
# The model takes the two images and checks if they correspond
# to the same original image.
xx = jn.concatenate([jn.abs(x),
jn.abs(y)],
axis=1)
return self.model(xx, training=False)
self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
self.predict_fast = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
model = DoesUseSame()
checkpoint = objax.io.Checkpoint("models/step1/", keep_ckpts=5, makedir=True)
start_epoch, last_ckpt = checkpoint.restore(model.vars())
def doall():
global graph
n = np.load("data/encryption.npy")
n = np.transpose(n, (0,3,1,2))
# Compute the similarity between each encoded image and all others
# This is n^2 work but should run fairly quickly, especially given
# more than one GPU. Otherwise about an hour or so.
graph = []
with model.vars().replicate():
for i in range(5000):
print(i)
v = model.predict_fast(np.tile(n[i:i+1], (5000,1,1,1)), n)
graph.append(np.array(v[:,0]-v[:,1]))
graph = np.array(graph)
np.save("data/graph.npy", graph)
if __name__ == "__main__":
setup()
doall()


@ -1,95 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import pickle
import random
import collections
import numpy as np
def score(subset):
sub = graph[subset]
sub = sub[:,subset]
return np.sum(sub)
def run(v, return_scores=False):
if isinstance(v, int):
v = [v]
scores = []
for _ in range(100):
keep = graph[v,:]
next_value = np.sum(keep,axis=0)
to_add = next_value.argsort()
to_add = [x for x in to_add if x not in v]
if _ < 1:
v.append(to_add[random.randint(0,10)])
else:
v.append(to_add[0])
if return_scores:
scores.append(score(v)/len(keep))
if return_scores:
return v, scores
else:
return v
def make_many_clusters():
# Compute clusters of 100 examples that probably correspond to some original image
p = mp.Pool(mp.cpu_count())
s = p.map(run, range(2000))
return s
def downselect_clusters(s):
# Right now we have a lot of clusters, but they probably overlap. Let's remove that.
# We want to find disjoint clusters, so we'll greedily add them until we have
# 100 disjoint clusters.
ss = [set(x) for x in s]
keep = []
keep_set = []
for iteration in range(2):
for this_set in s:
# MAGIC NUMBERS...!
# We want clusters of size 50 because it works
# Except on iteration 2 where we'll settle for 25 if we haven't
# found clusters with 50 neighbors that work.
cur = set(this_set[:50 - 25*iteration])
intersections = np.array([len(cur & x) for x in ss])
good = np.sum(intersections==50)>2
# Good means that this cluster isn't a fluke and some other cluster
# is like this one.
if good or iteration == 1:
print("N")
# And also make sure we haven't found this cluster (or one like it).
already_found = np.array([len(cur & x) for x in keep_set])
if np.all(already_found<len(cur)/2):
print("And is new")
keep.append(this_set)
keep_set.append(set(this_set))
if len(keep) == 100:
break
print("Found", len(keep))
if len(keep) == 100:
break
# Keep should now have 100 items.
# If it doesn't, go and change the 2000 in make_many_clusters to a bigger number.
return keep
if __name__ == "__main__":
graph = np.load("data/graph.npy")
np.save("data/many_clusters",make_many_clusters())
np.save("data/100_clusters", downselect_clusters(np.load("data/many_clusters.npy")))


@ -1,114 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Create the improved graph mapping each encoded image to an original image.
"""
import objax
import numpy as np
import jax.numpy as jn
import functools
import os
import random
from objax.zoo import wide_resnet
def setup():
global model
class DoesUseSame(objax.Module):
def __init__(self):
fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
self.model = fn(3*4,2)
model_vars = self.model.vars()
self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
def predict_op(x,y):
# The model takes SEVERAL images and checks if they all correspond
# to the same original image.
# Guaranteed that the first N-1 all do, the test is if the last does.
xx = jn.concatenate([jn.abs(x),
jn.abs(y)],
axis=1)
return self.model(xx, training=False)
self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
model = DoesUseSame()
checkpoint = objax.io.Checkpoint("models/step2/", keep_ckpts=5, makedir=True)
start_epoch, last_ckpt = checkpoint.restore(model.vars())
def step2():
global v, n, u, nextgraph
# Start out by loading the encoded images
n = np.load("data/encryption.npy")
n = np.transpose(n, (0,3,1,2))
# Then load the graph with 100 cluster-centers.
keep = np.array(np.load("data/100_clusters.npy", allow_pickle=True))
graph = np.load("data/graph.npy")
# Now we're going to record the distance to each of the cluster centers
# from every encoded image, so that we can do the matching.
# To do that, though, first we need to choose the cluster centers.
# Start out by choosing the best cluster centers.
distances = []
for x in keep:
this_set = x[:50]
use_elts = graph[this_set]
distances.append(np.sum(use_elts,axis=0))
distances = np.array(distances)
ds = np.argsort(distances, axis=1)
# Now we record the "prototypes" of each cluster center.
# We just need three, more might help a little bit but not much.
# (And then do that ten times, so we can average out noise
# with respect to which cluster centers we picked.)
prototypes = []
for _ in range(10):
ps = []
# choose 3 random samples from each set
for i in range(3):
ps.append(n[ds[:,random.randint(0,20)]])
prototypes.append(np.concatenate(ps,1))
prototypes = np.concatenate(prototypes,0)
# Finally compute the distances from each node to each cluster center.
nextgraph = []
for i in range(5000):
out = model.predict(prototypes, np.tile(n[i:i+1], (1000,1,1,1)))
out = out.reshape((10, 100, 2))
v = np.sum(out,axis=0)
v = v[:,0] - v[:,1]
v = np.array(v)
nextgraph.append(v)
np.save("data/nextgraph.npy", nextgraph)
if __name__ == "__main__":
setup()
step2()


@ -1,51 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import pickle
import random
import numpy as np
labels = np.load("data/label.npy")
nextgraph = np.load("data/nextgraph.npy")
assigned = [[] for _ in range(5000)]
lambdas = [[] for _ in range(5000)]
for i in range(100):
order = (np.argsort(nextgraph[:,i]))
correct = (labels[order[:20]]>0).sum(axis=0).argmax()
# Let's create the final graph
# Instead of doing a full bipartite matching, let's just greedily
# choose the closest 80 candidates for each encoded image to pair
# together and call it a day.
# This is within a percent or two of doing that, and much easier.
# Also record the lambdas based on which image it corresponds to,
# but if they share a label then just guess it's an even 50/50 split.
for x in order[:80]:
if labels[x][correct] > 0 and len(assigned[x]) < 2:
assigned[x].append(i)
if np.sum(labels[x]>0) == 1:
# the same label was mixed in twice. punt.
lambdas[x].append(labels[x][correct]/2)
else:
lambdas[x].append(labels[x][correct])
np.save("data/predicted_pairings_80.npy", assigned)
np.save("data/predicted_lambdas_80.npy", lambdas)


@ -1,143 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The final recovery happens here. Given the graph, reconstruct images.
"""
import json
import numpy as np
import jax.numpy as jn
import jax
import collections
from PIL import Image
import jax.experimental.optimizers
import matplotlib.pyplot as plt
def toimg(x):
#x = np.transpose(x,(1,2,0))
print(x.shape)
img = (x+1)*127.5
return Image.fromarray(np.array(img,dtype=np.uint8))
def explained_variance(I, private_images, lambdas, encoded_images, public_to_private, return_mat=False):
# private images: 100x32x32x3
# encoded images: 5000x32x32x3
public_to_private = jax.nn.softmax(public_to_private,axis=-1)
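# public_to_private has shape (2, 5000, 100); the softmax turns each row into a
# soft assignment of an encoded image to one of the 100 private images.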
# Compute the components from each of the images we know should map onto the same original image.
component_1 = jn.dot(public_to_private[0], private_images.reshape((100,-1))).reshape((5000,32,32,3))
component_2 = jn.dot(public_to_private[1], private_images.reshape((100,-1))).reshape((5000,32,32,3))
# Now combine them together to get the variance we can explain
merged = component_1 * lambdas[:,0][:,None,None,None] + component_2 * lambdas[:,1][:,None,None,None]
# And now get the variance we can't explain.
# This is the contribution of the public images.
# We want this value to be small.
def keep_smallest_abs(xx1, xx2):
t = 0
which = (jn.abs(xx1+t) < jn.abs(xx2+t)) + 0.0
return xx1 * which + xx2 * (1-which)
xx1 = jn.abs(encoded) - merged
xx2 = -(jn.abs(encoded) + merged)
xx = keep_smallest_abs(xx1, xx2)
unexplained_variance = xx
if return_mat:
return unexplained_variance, xx1, xx2
extra = (1-jn.abs(private_images)).mean()*.05
return extra + (unexplained_variance**2).mean()
def setup():
global private, imagenet40, encoded, lambdas, using, real_using, pub_using
# Load all the things we've made.
encoded = np.load("data/encryption.npy")
labels = np.load("data/label.npy")
using = np.load("data/predicted_pairings_80.npy", allow_pickle=True)
lambdas = list(np.load("data/predicted_lambdas_80.npy", allow_pickle=True))
for x in lambdas:
while len(x) < 2:
x.append(0)
lambdas = np.array(lambdas)
# Construct the mapping
public_to_private_new = np.zeros((2, 5000, 100))
cs = [0]*100
for i,row in enumerate(using):
for j,b in enumerate(row[:2]):
public_to_private_new[j, i, b] = 1e9
cs[b] += 1
using = public_to_private_new
def loss(private, lams, I):
return explained_variance(I, private, lams, jn.array(encoded), jn.array(using))
def make_loss():
global vg
vg = jax.jit(jax.value_and_grad(loss, argnums=(0,1)))
def run():
priv = np.zeros((100,32,32,3))
uusing = np.array(using)
lams = np.array(lambdas)
# Use Adam, because thinking hard is overrated when we have magic pixie dust.
init_1, opt_update_1, get_params_1 = jax.experimental.optimizers.adam(.01)
@jax.jit
def update_1(i, opt_state, gs):
return opt_update_1(i, gs, opt_state)
opt_state_1 = init_1(priv)
# 1000 iterations of gradient descent is probably enough
for i in range(1000):
value, grad = vg(priv, lams, i)
if i%100 == 0:
print(value)
var,_,_ = explained_variance(0, priv, jn.array(lambdas), jn.array(encoded), jn.array(using),
return_mat=True)
print('unexplained min/max', var.min(), var.max())
opt_state_1 = update_1(i, opt_state_1, grad[0])
priv = opt_state_1.packed_state[0][0]
priv -= np.min(priv, axis=(1,2,3), keepdims=True)
priv /= np.max(priv, axis=(1,2,3), keepdims=True)
priv *= 2
priv -= 1
# Finally save the stored values
np.save("data/private_raw.npy", priv)
if __name__ == "__main__":
setup()
make_loss()
run()


@ -1,66 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fix the color curves. Use a pre-trained "neural network" with <100 weights.
Visually this helps a lot, even if it's not doing much of anything in practice.
"""
import random
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import numpy as np
import jax.numpy as jn
import objax
# Our extremely complicated neural network to re-color the images.
# Takes one pixel at a time and fixes the color of that pixel.
model = objax.nn.Sequential([objax.nn.Linear(3, 10),
objax.functional.relu,
objax.nn.Linear(10, 3)
])
# These are the weights.
weights = [[-0.09795442, -0.26434848, -0.24964345, -0.11450608, 0.6797288, -0.48435465,
0.45307165, -0.31196147, -0.33266315, 0.20486055],
[[-0.9056427, 0.02872663, -1.5114126, -0.41024876, -0.98195165, 0.1143966,
0.6763464, -0.58654785, -1.797063, -0.2176538, ],
[ 1.1941166, 0.15515928, 1.1691351, -0.7256186, 0.8046044, 1.3127686,
-0.77297133, -1.1761239, 0.85841715, 0.95545965],
[ 0.20092924, 0.57503146, 0.22809981, 1.5288007, -0.94781816, -0.68305916,
-0.5245211, 1.4042739, -0.00527458, -1.1462274, ]],
[0.15683544, 0.22086962, 0.33100453],
[[ 7.7239674e-01, 4.0261227e-01, -9.6466336e-03],
[-2.2159107e-01, 1.5123411e-01, 3.4485441e-01],
[-1.7618114e+00, -7.1886492e-01, -4.6467595e-02],
[ 6.9419539e-01, 6.2531930e-01, 7.2271496e-01],
[-1.1913675e+00, -6.7755884e-01, -3.5114303e-01],
[ 4.8022485e-01, 1.7145030e-01, 7.4849324e-04],
[ 3.8332436e-02, -7.0614147e-01, -5.5127507e-01],
[-1.0929481e+00, -1.0268525e+00, -7.0265180e-01],
[ 1.4880739e+00, 7.1450096e-01, 2.9102692e-01],
[ 7.2846663e-01, 7.1322352e-01, -1.7453632e-01]]]
for i,(k,v) in enumerate(model.vars().items()):
v.assign(jn.array(weights[i]))
# Do all of the re-coloring
predict = objax.Jit(lambda x: model(x, training=False),
model.vars())
out = model(np.load("data/private_raw.npy"))
np.save("data/private.npy", out)


@ -1,38 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Given the private images, draw them in a 10x10 grid for visualization.
"""
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
p = np.load("data/private.npy")
def toimg(x):
print(x.shape)
img = (x+1)*127.5
img = np.clip(img, 0, 255)
img = np.reshape(img, (10, 10, 32, 32, 3))
img = np.concatenate(img, axis=2)
img = np.concatenate(img, axis=0)
img = Image.fromarray(np.array(img,dtype=np.uint8))
return img
toimg(p).save("data/reconstructed.png")


@ -1,129 +0,0 @@
## Membership Inference Attacks From First Principles
This directory contains code to reproduce our paper:
**"Membership Inference Attacks From First Principles"**
https://arxiv.org/abs/2112.03570
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
### INSTALLING
You will need to install fairly standard dependencies
`pip install scipy scikit-learn numpy matplotlib`
and also some machine learning framework to train models. We train our models
with JAX + Objax, so you will need to follow the installation instructions for those:
https://github.com/google/objax
https://objax.readthedocs.io/en/latest/installation_setup.html
### RUNNING THE CODE
#### 1. Train the models
The first step in our attack is to train shadow models. As a baseline
that should give most of the gains in our attack, you should start by
training 16 shadow models with the command
> bash scripts/train_demo.sh
or if you have multiple GPUs on your machine and want to train these models
in parallel, then modify and run
> bash scripts/train_demo_multigpu.sh
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
will output a bunch of files under the directory exp/cifar10 with structure:
```
exp/cifar10/
- experiment_N_of_16
-- hparams.json
-- keep.npy
-- ckpt/
--- 0000000100.npz
-- tb/
```
#### 2. Perform inference
Once the models are trained, the next step is to perform inference and save
the output features of each model on every training example in the dataset.
> python3 inference.py --logdir=exp/cifar10/
This will add to the experiment directory a new set of files
```
exp/cifar10/
- experiment_N_of_16
-- logits/
--- 0000000100.npy
```
where this new file has shape (50000, 10) and stores the model's
output features for each example.
#### 3. Compute membership inference scores
Finally we take the output features and generate our logit-scaled membership inference
scores for each example for each model.
> python3 score.py exp/cifar10/
And this in turn generates a new directory
```
exp/cifar10/
- experiment_N_of_16
-- scores/
--- 0000000100.npy
```
with shape (50000,) storing just our scores.
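For reference, here is a minimal sketch of the logit scaling the paper describes, under the assumption that `score.py` maps each example's softmax confidence p on its true class to log(p / (1 - p)); the helper name and array shapes below are illustrative, not the script's actual interface.

```python
import numpy as np


def logit_scale(outputs, labels, eps=1e-45):
  """outputs: (n_examples, n_classes) model outputs; labels: (n_examples,) int labels."""
  # Numerically stable softmax over the class axis.
  z = outputs - outputs.max(axis=-1, keepdims=True)
  probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
  # Confidence assigned to the true class, then the logit transform.
  p_true = probs[np.arange(len(labels)), labels]
  return np.log(p_true + eps) - np.log(1.0 - p_true + eps)
```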
### PLOTTING THE RESULTS
Finally we can generate pretty pictures, and run the plotting code
> python3 plot.py
which should give (something like) the following output
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
```
Attack Ours (online)
AUC 0.6676, Accuracy 0.6077, TPR@0.1%FPR of 0.0169
Attack Ours (online, fixed variance)
AUC 0.6856, Accuracy 0.6137, TPR@0.1%FPR of 0.0593
Attack Ours (offline)
AUC 0.5488, Accuracy 0.5500, TPR@0.1%FPR of 0.0130
Attack Ours (offline, fixed variance)
AUC 0.5549, Accuracy 0.5537, TPR@0.1%FPR of 0.0299
Attack Global threshold
AUC 0.5921, Accuracy 0.6044, TPR@0.1%FPR of 0.0009
```
where the global threshold attack is the baseline, and our online,
online-with-fixed-variance, offline, and offline-with-fixed-variance
attack variants are the four other curves. Note that because we only
train a few models, the fixed variance variants perform best.
### Citation
You can cite this paper with
```
@article{carlini2021membership,
title={Membership Inference Attacks From First Principles},
author={Carlini, Nicholas and Chien, Steve and Nasr, Milad and Song, Shuang and Terzis, Andreas and Tramer, Florian},
journal={arXiv preprint arXiv:2112.03570},
year={2021}
}
```


@ -1,95 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, List
import numpy as np
import tensorflow as tf
def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
features = tf.io.parse_single_example(serialized_example,
features={'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)})
image = tf.image.decode_image(features['image'])
image.set_shape(image_shape)
image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
return dict(image=image, label=features['label'])
class DataSet:
"""Wrapper for tf.data.Dataset to permit extensions."""
def __init__(self, data: tf.data.Dataset,
image_shape: Tuple[int, int, int],
augment_fn: Optional[Callable] = None,
parse_fn: Optional[Callable] = record_parse):
self.data = data
self.parse_fn = parse_fn
self.augment_fn = augment_fn
self.image_shape = image_shape
@classmethod
def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
augment_fn=augment_fn, parse_fn=None)
@classmethod
def from_files(cls, filenames: List[str],
image_shape: Tuple[int, int, int],
augment_fn: Optional[Callable],
parse_fn: Optional[Callable] = record_parse):
filenames_in = filenames
filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
if not filenames:
raise ValueError('Empty dataset, files not found:', filenames_in)
return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)
@classmethod
def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
augment_fn: Optional[Callable] = None):
return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
image_shape, augment_fn=augment_fn, parse_fn=None)
def __iter__(self):
return iter(self.data)
def __getattr__(self, item):
if item in self.__dict__:
return self.__dict__[item]
def call_and_update(*args, **kwargs):
v = getattr(self.__dict__['data'], item)(*args, **kwargs)
if isinstance(v, tf.data.Dataset):
return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
return v
return call_and_update
def augment(self, para_augment: int = 4):
if self.augment_fn:
return self.map(self.augment_fn, para_augment)
return self
def nchw(self):
return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))
def one_hot(self, nclass: int):
return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))
def parse(self, para_parse: int = 2):
if not self.parse_fn:
return self
if self.image_shape:
return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
return self.map(self.parse_fn, para_parse)

Binary file not shown (previously 37 KiB).


@ -1,150 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Callable
import json
import re
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import pickle
from functools import partial
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet
from dataset import DataSet
from train import MemModule, network
from collections import defaultdict
FLAGS = flags.FLAGS
def main(argv):
"""
Perform inference of the saved model in order to generate the
output logits, using a particular set of augmentations.
"""
del argv
tf.config.experimental.set_visible_devices([], "GPU")
def load(arch):
return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10,
mnist=FLAGS.dataset == 'mnist',
arch=arch,
lr=.1,
batch=0,
epochs=0,
weight_decay=0)
def cache_load(arch):
thing = []
def fn():
if len(thing) == 0:
thing.append(load(arch))
return thing[0]
return fn
xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
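# Evaluate the model on the original (and optionally mirrored) copy of each
# example at every (dx, dy) shift, collecting the logits for all augmentations.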
outs = []
for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]:
aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy()
for dx in range(0, 2*shift+1, stride):
for dy in range(0, 2*shift+1, stride):
this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))
logits = model.model(this_x, training=True)
outs.append(logits)
print(np.array(outs).shape)
return np.array(outs).transpose((1, 0, 2))
N = 5000
def features(model, xbatch, ybatch):
return get_loss(model, xbatch, ybatch,
shift=0, reflect=True, stride=1)
for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
if re.search(FLAGS.regex, path) is None:
print("Skipping from regex")
continue
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
arch = hparams['arch']
model = cache_load(arch)()
logdir = os.path.join(FLAGS.logdir, path)
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
max_epoch, last_ckpt = checkpoint.restore(model.vars())
if max_epoch == 0: continue
if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")):
os.mkdir(os.path.join(FLAGS.logdir, path, "logits"))
if FLAGS.from_epoch is not None:
first = FLAGS.from_epoch
else:
first = max_epoch-1
for epoch in range(first,max_epoch+1):
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
# no checkpoint saved here
continue
if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
print("Skipping already generated file", epoch)
continue
try:
start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
except:
print("Fail to load", epoch)
continue
stats = []
for i in range(0,len(xs_all),N):
stats.extend(features(model, xs_all[i:i+N],
ys_all[i:i+N]))
# This will be shape N, augs, nclass
np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch),
np.array(stats)[:,None,:,:])
if __name__ == '__main__':
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
flags.DEFINE_bool('random_labels', False, 'use random labels.')
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
flags.DEFINE_integer('modulus', 8, 'modulus.')
app.run(main)


@ -1,224 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import scipy.stats
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve
import functools
# Look at me being proactive!
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
def sweep(score, x):
"""
Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
"""
fpr, tpr, _ = roc_curve(x, -score)
acc = np.max(1-(fpr+(1-tpr))/2)
return fpr, tpr, auc(fpr, tpr), acc
def load_data(p):
"""
Load our saved scores and then put them into a big matrix.
"""
global scores, keep
scores = []
keep = []
for root,ds,_ in os.walk(p):
for f in ds:
if not f.startswith("experiment"): continue
if not os.path.exists(os.path.join(root,f,"scores")): continue
last_epoch = sorted(os.listdir(os.path.join(root,f,"scores")))
if len(last_epoch) == 0: continue
scores.append(np.load(os.path.join(root,f,"scores",last_epoch[-1])))
keep.append(np.load(os.path.join(root,f,"keep.npy")))
scores = np.array(scores)
keep = np.array(keep)[:,:scores.shape[1]]
return scores, keep
def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
fix_variance=False):
"""
Fit two predictive models using keep and scores in order to predict
if the examples in check_scores were training data or not, using the
ground truth answer from check_keep.
"""
dat_in = []
dat_out = []
for j in range(scores.shape[1]):
dat_in.append(scores[keep[:,j],j,:])
dat_out.append(scores[~keep[:,j],j,:])
in_size = min(min(map(len,dat_in)), in_size)
out_size = min(min(map(len,dat_out)), out_size)
dat_in = np.array([x[:in_size] for x in dat_in])
dat_out = np.array([x[:out_size] for x in dat_out])
mean_in = np.median(dat_in, 1)
mean_out = np.median(dat_out, 1)
if fix_variance:
# Use a single global standard deviation shared by both Gaussians.
std_in = np.std(dat_in)
std_out = np.std(dat_in)
else:
std_in = np.std(dat_in, 1)
std_out = np.std(dat_out, 1)
prediction = []
answers = []
for ans, sc in zip(check_keep, check_scores):
pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in+1e-30)
pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
score = pr_in-pr_out
prediction.extend(score.mean(1))
answers.extend(ans)
return prediction, answers
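# Illustrative sketch (not part of the original file): generate_ours above
# is a per-example Gaussian likelihood-ratio test. The toy version below
# shows the same computation for a single example with 1-D scores; all
# numbers are made up.
def _demo_gaussian_likelihood_ratio():
    rng = np.random.default_rng(0)
    in_scores = rng.normal(4.0, 1.0, size=64)    # scores from models trained on the example
    out_scores = rng.normal(1.0, 1.5, size=64)   # scores from models not trained on it
    mean_in, std_in = np.median(in_scores), np.std(in_scores)
    mean_out, std_out = np.median(out_scores), np.std(out_scores)
    observed = 3.5                               # score under the model being attacked
    pr_in = -scipy.stats.norm.logpdf(observed, mean_in, std_in + 1e-30)
    pr_out = -scipy.stats.norm.logpdf(observed, mean_out, std_out + 1e-30)
    print("membership score (lower = more member-like):", pr_in - pr_out)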
def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
fix_variance=False):
"""
Fit a single predictive model using keep and scores in order to predict
if the examples in check_scores were training data or not, using the
ground truth answer from check_keep.
"""
dat_in = []
dat_out = []
for j in range(scores.shape[1]):
dat_in.append(scores[keep[:, j], j, :])
dat_out.append(scores[~keep[:, j], j, :])
out_size = min(min(map(len,dat_out)), out_size)
dat_out = np.array([x[:out_size] for x in dat_out])
mean_out = np.median(dat_out, 1)
if fix_variance:
std_out = np.std(dat_out)
else:
std_out = np.std(dat_out, 1)
prediction = []
answers = []
for ans, sc in zip(check_keep, check_scores):
score = scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
prediction.extend(score.mean(1))
answers.extend(ans)
return prediction, answers
def generate_global(keep, scores, check_keep, check_scores):
"""
Use a simple global threshold sweep to predict if the examples in
check_scores were training data or not, using the ground truth answer from
check_keep.
"""
prediction = []
answers = []
for ans, sc in zip(check_keep, check_scores):
prediction.extend(-sc.mean(1))
answers.extend(ans)
return prediction, answers
def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
"""
Generate the ROC curves by using ntest models as test models and the rest to train.
"""
prediction, answers = fn(keep[:-ntest],
scores[:-ntest],
keep[-ntest:],
scores[-ntest:])
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
low = tpr[np.where(fpr<.001)[0][-1]]
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
metric_text = ''
if metric == 'auc':
metric_text = 'auc=%.3f'%auc
elif metric == 'acc':
metric_text = 'acc=%.3f'%acc
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
return (acc,auc)
def fig_fpr_tpr():
plt.figure(figsize=(4,3))
do_plot(generate_ours,
keep, scores, 1,
"Ours (online)\n",
metric='auc'
)
do_plot(functools.partial(generate_ours, fix_variance=True),
keep, scores, 1,
"Ours (online, fixed variance)\n",
metric='auc'
)
do_plot(functools.partial(generate_ours_offline),
keep, scores, 1,
"Ours (offline)\n",
metric='auc'
)
do_plot(functools.partial(generate_ours_offline, fix_variance=True),
keep, scores, 1,
"Ours (offline, fixed variance)\n",
metric='auc'
)
do_plot(generate_global,
keep, scores, 1,
"Global threshold\n",
metric='auc'
)
plt.semilogx()
plt.semilogy()
plt.xlim(1e-5,1)
plt.ylim(1e-5,1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.plot([0, 1], [0, 1], ls='--', color='gray')
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
plt.legend(fontsize=8)
plt.savefig("/tmp/fprtpr.png")
plt.show()
load_data("exp/cifar10/")
fig_fpr_tpr()

View file

@ -1,66 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import os
import multiprocessing as mp
def load_one(base):
"""
Load the saved logits for one experiment and convert them to scored predictions.
"""
root = os.path.join(logdir,base,'logits')
if not os.path.exists(root): return None
if not os.path.exists(os.path.join(logdir,base,'scores')):
os.mkdir(os.path.join(logdir,base,'scores'))
for f in os.listdir(root):
try:
opredictions = np.load(os.path.join(root,f))
except Exception:
print("Failed to load", f)
continue
## Be exceptionally careful here:
## compute everything in a numerically stable way, as described in the paper.
predictions = opredictions - np.max(opredictions, axis=3, keepdims=True)
predictions = np.array(np.exp(predictions), dtype=np.float64)
predictions = predictions/np.sum(predictions,axis=3,keepdims=True)
COUNT = predictions.shape[0]
# x num_examples x num_augmentations x logits
y_true = predictions[np.arange(COUNT),:,:,labels[:COUNT]]
print(y_true.shape)
print('mean acc',np.mean(predictions[:,0,0,:].argmax(1)==labels[:COUNT]))
predictions[np.arange(COUNT),:,:,labels[:COUNT]] = 0
y_wrong = np.sum(predictions, axis=3)
logit = (np.log(y_true.mean((1))+1e-45) - np.log(y_wrong.mean((1))+1e-45))
np.save(os.path.join(logdir, base, 'scores', f), logit)
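# Illustrative sketch (not part of the original file): the score above is a
# numerically stable log-odds of the true class,
# log(p_true) - log(sum of p_wrong), averaged over augmentations. The toy
# function below redoes the arithmetic on one made-up logit vector.
def _demo_stable_logit_score():
    fake_logits = np.array([2.0, -1.0, 0.5])           # one example, three classes
    true_class = 0
    shifted = fake_logits - np.max(fake_logits)         # subtract max for stability
    probs = np.exp(shifted) / np.sum(np.exp(shifted))   # softmax
    p_true = probs[true_class]
    p_wrong = probs.sum() - p_true
    print(np.log(p_true + 1e-45) - np.log(p_wrong + 1e-45))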
def load_stats():
with mp.Pool(8) as p:
p.map(load_one, [x for x in os.listdir(logdir) if 'exp' in x])
logdir = sys.argv[1]
labels = np.load(os.path.join(logdir,"y_train.npy"))
load_stats()

View file

@ -1,16 +0,0 @@
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15

View file

@ -1,18 +0,0 @@
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
wait;
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
wait;

View file

@ -1,329 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import shutil
from typing import Callable
import json
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet, dnnet
from dataset import DataSet
FLAGS = flags.FLAGS
def augment(x, shift: int, mirror=True):
"""
Augmentation function used in training the model.
"""
y = x['image']
if mirror:
y = tf.image.random_flip_left_right(y)
y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT')
y = tf.image.random_crop(y, tf.shape(x['image']))
return dict(image=y, label=x['label'])
class TrainLoop(objax.Module):
"""
Training loop for general machine learning models.
Based on the training loop from the objax CIFAR10 example code.
"""
predict: Callable
train_op: Callable
def __init__(self, nclass: int, **kwargs):
self.nclass = nclass
self.params = EasyDict(kwargs)
def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
for k, v in kv.items():
if jn.isnan(v):
raise ValueError('NaN, try reducing learning rate', k)
if summary is not None:
summary.scalar(k, float(v))
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
"""
Completely standard training. Nothing interesting to see here.
"""
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
start_epoch, last_ckpt = checkpoint.restore(self.vars())
train_iter = iter(train)
progress = np.zeros(jax.local_device_count(), 'f') # for multi-GPU
best_acc = 0
best_acc_epoch = -1
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
for epoch in range(start_epoch, num_train_epochs):
# Train
summary = Summary()
loop = range(0, train_size, self.params.batch)
for step in loop:
progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
self.train_step(summary, next(train_iter), progress)
# Eval
accuracy, total = 0, 0
if epoch%FLAGS.eval_steps == 0 and test is not None:
for data in test:
total += data['image'].shape[0]
preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
accuracy += (preds == data['label'].numpy()).sum()
accuracy /= total
summary.scalar('eval/accuracy', 100 * accuracy)
tensorboard.write(summary, step=(epoch + 1) * train_size)
print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
summary['eval/accuracy']()))
if summary['eval/accuracy']() > best_acc:
best_acc = summary['eval/accuracy']()
best_acc_epoch = epoch
elif patience is not None and epoch > best_acc_epoch + patience:
print("early stopping!")
checkpoint.save(self.vars(), epoch + 1)
return
else:
print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))
if epoch%save_steps == save_steps-1:
checkpoint.save(self.vars(), epoch + 1)
# We inherit from the training loop and define predict and train_op.
class MemModule(TrainLoop):
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
"""
Standard training setup: the model, Momentum SGD, an EMA of the weights, and jitted train/predict ops.
"""
super().__init__(nclass, **kwargs)
self.model = model(1 if mnist else 3, nclass)
self.opt = objax.optimizer.Momentum(self.model.vars())
self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)
@objax.Function.with_vars(self.model.vars())
def loss(x, label):
logit = self.model(x, training=True)
loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}
gv = objax.GradValues(loss, self.model.vars())
self.gv = gv
@objax.Function.with_vars(self.vars())
def train_op(progress, x, y):
g, v = gv(x, y)
lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
lr = lr * jn.clip(progress*100,0,1)
self.opt(lr, g)
self.model_ema.update_ema()
return {'monitors/lr': lr, **v[1]}
self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
self.train_op = objax.Jit(train_op)
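# Illustrative sketch (not part of the original file): the learning-rate
# schedule in train_op above is a cosine decay, cos(progress * 7*pi/16),
# multiplied by a very short linear warmup, clip(progress * 100, 0, 1).
# The helper below just prints the resulting rate at a few points of
# training; the base rate of 0.1 is a made-up value.
def _demo_lr_schedule(base_lr=0.1):
    for progress in (0.0, 0.005, 0.01, 0.25, 0.5, 0.75, 1.0):
        lr = base_lr * np.cos(progress * (7 * np.pi) / (2 * 8))
        lr = lr * np.clip(progress * 100, 0, 1)
        print('progress %.3f -> lr %.4f' % (progress, lr))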
def network(arch: str):
if arch == 'cnn32-3-max':
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
pooling=objax.functional.max_pool_2d)
elif arch == 'cnn32-3-mean':
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
pooling=objax.functional.average_pool_2d)
elif arch == 'cnn64-3-max':
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
pooling=objax.functional.max_pool_2d)
elif arch == 'cnn64-3-mean':
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
pooling=objax.functional.average_pool_2d)
elif arch == 'wrn28-1':
return functools.partial(wide_resnet.WideResNet, depth=28, width=1)
elif arch == 'wrn28-2':
return functools.partial(wide_resnet.WideResNet, depth=28, width=2)
elif arch == 'wrn28-10':
return functools.partial(wide_resnet.WideResNet, depth=28, width=10)
raise ValueError('Architecture not recognized', arch)
def get_data(seed):
"""
This is the function to generate subsets of the data for training models.
First, we get the training dataset either from the numpy cache
or otherwise we load it from tensorflow datasets.
Then, we compute the subset. This works in one of two ways.
1. If we have a seed, then we just randomly choose examples based on
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
2. Otherwise, if we have an experiment ID, then we do something fancier.
If we run each experiment independently then even after a lot of trials
there will still probably be some examples that were always included
or always excluded. So instead, with experiment IDs, we guarantee that
after FLAGS.num_experiments are done, each example is seen exactly half
of the time in train, and half of the time not in train.
"""
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
else:
print("First time, creating dataset")
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']
inputs = (inputs/127.5)-1
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
nclass = np.max(labels)+1
np.random.seed(seed)
if FLAGS.num_experiments is not None:
np.random.seed(0)
keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size))
order = keep.argsort(0)
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
keep = np.array(keep[FLAGS.expid], dtype=bool)
else:
keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= FLAGS.pkeep
if FLAGS.only_subset is not None:
keep[FLAGS.only_subset:] = 0
xs = inputs[keep]
ys = labels[keep]
if FLAGS.augment == 'weak':
aug = lambda x: augment(x, 4)
elif FLAGS.augment == 'mirror':
aug = lambda x: augment(x, 0)
elif FLAGS.augment == 'none':
aug = lambda x: augment(x, 0, mirror=False)
else:
raise ValueError('Unknown augmentation option', FLAGS.augment)
train = DataSet.from_arrays(xs, ys,
augment_fn=aug)
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch)
train = train.nchw().one_hot(nclass).prefetch(16)
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16)
return train, test, xs, ys, keep, nclass
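# Illustrative sketch (not part of the original file): the experiment-ID
# branch above uses an argsort trick so that, across num_experiments runs,
# every example is kept by exactly pkeep * num_experiments of them. The
# small check below demonstrates this guarantee on made-up sizes.
def _demo_balanced_subsets(num_experiments=8, dataset_size=10, pkeep=0.5):
    rng = np.random.RandomState(0)
    uniform = rng.uniform(0, 1, size=(num_experiments, dataset_size))
    order = uniform.argsort(0)
    keep = order < int(pkeep * num_experiments)
    print(keep.sum(axis=0))  # every example is kept by exactly 4 of the 8 runs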
def main(argv):
del argv
tf.config.experimental.set_visible_devices([], "GPU")
seed = FLAGS.seed
if seed is None:
import time
seed = np.random.randint(0, 1000000000)
seed ^= int(time.time())
args = EasyDict(arch=FLAGS.arch,
lr=FLAGS.lr,
batch=FLAGS.batch,
weight_decay=FLAGS.weight_decay,
augment=FLAGS.augment,
seed=seed)
if FLAGS.tunename:
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
elif FLAGS.expid is not None:
logdir = "experiment-%d_%d"%(FLAGS.expid,FLAGS.num_experiments)
else:
logdir = "experiment-"+str(seed)
logdir = os.path.join(FLAGS.logdir, logdir)
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
print(f"run {FLAGS.expid} already completed.")
return
else:
if os.path.exists(logdir):
print(f"deleting run {FLAGS.expid} that did not complete.")
shutil.rmtree(logdir)
print(f"starting run {FLAGS.expid}.")
if not os.path.exists(logdir):
os.makedirs(logdir)
train, test, xs, ys, keep, nclass = get_data(seed)
# Define the network and train it
tm = MemModule(network(FLAGS.arch), nclass=nclass,
mnist=FLAGS.dataset == 'mnist',
epochs=FLAGS.epochs,
expid=FLAGS.expid,
num_experiments=FLAGS.num_experiments,
pkeep=FLAGS.pkeep,
save_steps=FLAGS.save_steps,
only_subset=FLAGS.only_subset,
**args
)
r = {}
r.update(tm.params)
open(os.path.join(logdir,'hparams.json'),"w").write(json.dumps(tm.params))
np.save(os.path.join(logdir,'keep.npy'), keep)
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
save_steps=FLAGS.save_steps, patience=FLAGS.patience)
if __name__ == '__main__':
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
flags.DEFINE_integer('batch', 256, 'Batch size')
flags.DEFINE_integer('epochs', 501, 'Training duration in number of epochs.')
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_integer('seed', None, 'Training seed.')
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
flags.DEFINE_integer('expid', None, 'Experiment ID')
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
flags.DEFINE_string('augment', 'weak', 'Augmentation to use: weak, mirror, or none.')
flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.')
flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.')
flags.DEFINE_integer('eval_steps', 1, 'How often (in epochs) to compute eval accuracy.')
flags.DEFINE_integer('abort_after_epoch', None, 'Stop training early at this epoch.')
flags.DEFINE_integer('save_steps', 10, 'How often (in epochs) to save the model.')
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
flags.DEFINE_bool('tunename', False, 'Use tune name?')
app.run(main)

View file

@ -1,712 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This program solves the NeuraCrypt challenge to 100% accuracy.
# Given a set of encoded images and original versions of those,
# it shows how to match the original to the encoded.
import collections
import hashlib
import time
import multiprocessing as mp
import torch
import numpy as np
import torch.nn as nn
import scipy.stats
import matplotlib.pyplot as plt
from PIL import Image
import jax
import jax.numpy as jn
import objax
import scipy.optimize
import numpy as np
import multiprocessing as mp
# Objax neural network that's going to embed patches to a
# low dimensional space to guess if two patches correspond
# to the same original image.
class Model(objax.Module):
def __init__(self):
IN = 15
H = 64
self.encoder =objax.nn.Sequential([
objax.nn.Linear(IN, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, 8)])
self.decoder =objax.nn.Sequential([
objax.nn.Linear(IN, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, H),
objax.functional.leaky_relu,
objax.nn.Linear(H, 8)])
self.scale = objax.nn.Linear(1, 1, use_bias=False)
def encode(self, x):
# Encode turns original images into feature space
a = self.encoder(x)
a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5
return a
def decode(self, x):
# And decode turns encoded images into feature space
a = self.decoder(x)
a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5
return a
# Proxy dataset for analysis
class ImageNet:
num_chan = 3
private_kernel_size = 16
hidden_dim = 2048
img_size = (256, 256)
private_depth = 7
def __init__(self, remove):
self.remove_pixel_shuffle = remove
# Original dataset as used in the NeuraCrypt paper
class Xray:
num_chan = 1
private_kernel_size = 16
hidden_dim = 2048
img_size = (256, 256)
private_depth = 4
def __init__(self, remove):
self.remove_pixel_shuffle = remove
## The following class is taken directly from the NeuraCrypt codebase.
## https://github.com/yala/NeuraCrypt
## which is originally licensed under the MIT License
class PrivateEncoder(nn.Module):
def __init__(self, args, width_factor=1):
super(PrivateEncoder, self).__init__()
self.args = args
input_dim = args.num_chan
patch_size = args.private_kernel_size
output_dim = args.hidden_dim
num_patches = (args.img_size[0] // patch_size) **2
self.noise_size = 1
args.input_dim = args.hidden_dim
layers = [
nn.Conv2d(input_dim, output_dim * width_factor, kernel_size=patch_size, dilation=1 ,stride=patch_size),
nn.ReLU()
]
for _ in range(self.args.private_depth):
layers.extend( [
nn.Conv2d(output_dim * width_factor, output_dim * width_factor , kernel_size=1, dilation=1, stride=1),
nn.BatchNorm2d(output_dim * width_factor, track_running_stats=False),
nn.ReLU()
])
self.image_encoder = nn.Sequential(*layers)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, output_dim * width_factor))
self.mixer = nn.Sequential( *[
nn.ReLU(),
nn.Linear(output_dim * width_factor, output_dim)
])
def forward(self, x):
encoded = self.image_encoder(x)
B, C, H,W = encoded.size()
encoded = encoded.view([B, -1, H*W]).transpose(1,2)
encoded += self.pos_embedding
encoded = self.mixer(encoded)
## Shuffle indices
if not self.args.remove_pixel_shuffle:
shuffled = torch.zeros_like(encoded)
for i in range(B):
idx = torch.randperm(H*W, device=encoded.device)
for j, k in enumerate(idx):
shuffled[i,j] = encoded[i,k]
encoded = shuffled
return encoded
## End copied code
def setup(ds):
"""
Load the datasets to use. Nothing interesting to see.
"""
global x_train, y_train
if ds == 'imagenet':
import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(256),
torchvision.transforms.ToTensor()])
imagenet_data = torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/',
split='val',
transform=transform)
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=100,
shuffle=True,
num_workers=8)
r = []
for x,_ in data_loader:
if len(r) > 1000: break
print(x.shape)
r.extend(x.numpy())
x_train = np.array(r)
print(x_train.shape)
elif ds == 'xray':
import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(256),
torchvision.transforms.ToTensor()])
imagenet_data = torchvision.datasets.ImageFolder('CheXpert-v1.0/train',
transform=transform)
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=100,
shuffle=True,
num_workers=8)
r = []
for x,_ in data_loader:
if len(r) > 1000: break
print(x.shape)
r.extend(x.numpy())
x_train = np.array(r)
print(x_train.shape)
elif ds == 'challenge':
x_train = np.load("orig-7.npy")
print(np.min(x_train), np.max(x_train), x_train.shape)
else:
raise ValueError('Unknown dataset', ds)
def gen_train_data():
"""
Generate aligned training data to train a patch similarity function.
Given some original images, generate lots of encoded versions.
"""
global encoded_train, original_train
encoded_train = []
original_train = []
args = Xray(True)
C = 100
for i in range(30):
print(i)
torch.manual_seed(int(time.time()))
e = PrivateEncoder(args).cuda()
batch = np.random.randint(0, len(x_train), size=C)
xin = x_train[batch]
r = []
for i in range(0,C,32):
r.extend(e(torch.tensor(xin[i:i+32]).cuda()).detach().cpu().numpy())
r = np.array(r)
encoded_train.append(r)
original_train.append(xin)
def features_(x, moments=15, encoded=False):
"""
Compute higher-order moments for patches in an image to use as
features for the neural network.
"""
x = np.array(x, dtype=np.float32)
dim = 2
arr = np.array([np.mean(x, dim)] + [abs(scipy.stats.moment(x, moment=i, axis=dim))**(1/i) for i in range(1,moments)])
return arr.transpose((1,2,0))
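# Illustrative sketch (not part of the original file): features_ above
# summarizes every patch by its mean together with |central moment|^(1/i)
# for i = 1..moments-1, computed over the pixels of the patch. The toy
# call below only checks the output shape on made-up data.
def _demo_patch_moments():
    toy = np.random.RandomState(0).normal(size=(2, 5, 9))  # 2 images, 5 patches, 9 pixels
    out = features_(toy, moments=6)
    print(out.shape)  # -> (2, 5, 6): mean plus five moment features per patch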
def features(x, encoded):
"""
Given the original images or the encoded images, generate the
features to use for the patch similarity function.
"""
print('start shape',x.shape)
if len(x.shape) == 3:
x = x - np.mean(x,axis=0,keepdims=True)
else:
# count x 100 x 256 x 768
print(x[0].shape)
x = x - np.mean(x,axis=1,keepdims=True)
# remove per-neural-network dimension
x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
p = mp.Pool(96)
B = len(x)//96
print(1)
bs = [x[i:i+B] for i in range(0,len(x),B)]
print(2)
r = p.map(features_, bs)
#r = features_(bs[0][:100])
print(3)
p.close()
#r = np.array(r)
#print('finish',r.shape)
return np.concatenate(r, axis=0)
def get_train_features():
"""
Create features for the entire datasets.
"""
global xs_train, ys_train
print(x_train.shape)
original_train_ = np.array(original_train)
encoded_train_ = np.array(encoded_train)
print("Computing features")
ys_train = features(encoded_train_, True)
patch_size = 16
ss = original_train_.shape[3]//patch_size
# Okay so this is an ugly transpose block.
# We are going from [outer_batch, batch_size, channels, width, height]
# to [outer_batch, batch_size, channels, width/patch_size, patch_size, height/patch_size, patch_size].
# Then we reshape this and flatten so that we end up with
# [outer_batch, batch_size, width/patch_size * height/patch_size, patch_size**2 * channels],
# so that now we can run features on the last dimension.
original_train_ = original_train_.reshape((original_train_.shape[0],
original_train_.shape[1],
original_train_.shape[2],
ss,patch_size,ss,patch_size)).transpose((0,1,3,5,2,4,6)).reshape((original_train_.shape[0], original_train_.shape[1], ss**2, patch_size**2))
xs_train = features(original_train_, False)
print(xs_train.shape, ys_train.shape)
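# Illustrative sketch (not part of the original file): the transpose block
# above cuts each image into non-overlapping patches and flattens every
# patch so the feature extractor can run over the last axis. The toy check
# below applies the same reshape pattern to one made-up 1-channel 8x8
# image with 4x4 patches.
def _demo_patchify():
    patch = 4
    img = np.arange(64, dtype=np.float32).reshape(1, 1, 8, 8)  # [batch, chan, H, W]
    ss_ = img.shape[3] // patch
    patches = (img.reshape(1, 1, ss_, patch, ss_, patch)
                  .transpose(0, 2, 4, 1, 3, 5)
                  .reshape(1, ss_ ** 2, patch ** 2))
    print(patches.shape)  # -> (1, 4, 16): four patches of sixteen pixels each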
def train_model():
"""
Train the patch similarity function
"""
global ema, model
model = Model()
def loss(x, y):
"""
K-way contrastive loss as in SimCLR et al.
The idea is that we should embed x and y so that they are similar
to each other, and dissimilar from the others. To do this we have a
softmax loss over one dimension to make the values large on the diagonal
and small off-diagonal.
"""
a = model.encode(x)
b = model.decode(y)
mat = a@b.T
return objax.functional.loss.cross_entropy_logits_sparse(
logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat,
labels=np.arange(a.shape[0])).mean()
ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)
gv = objax.GradValues(loss, model.vars())
encode_ema = ema.replace_vars(lambda x: model.encode(x))
decode_ema = ema.replace_vars(lambda y: model.decode(y))
def train_op(x, y):
"""
No one was ever fired for using Adam with 1e-4.
"""
g, v = gv(x, y)
opt(1e-4, g)
ema()
return v
opt = objax.optimizer.Adam(model.vars())
train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars())
ys_ = ys_train
print(ys_.shape)
xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
ys_ = ys_.reshape((-1, ys_train.shape[-1]))
# The model scale trick here is taken from CLIP.
# Let the model decide how confident to make its own predictions.
model.scale.w.assign(jn.zeros((1,1)))
valid_size = 1000
print(xs_train.shape)
# SimCLR likes big batches
B = 4096
for it in range(80):
print()
ms = []
for i in range(1000):
# First batch is smaller, to make training more stable
bs = [B//64, B][it>0]
batch = np.random.randint(0, len(xs_)-valid_size, size=bs)
r = train_op(xs_[batch], ys_[batch])
# This shouldn't happen, but if it does, better to abort early
if np.isnan(r):
print("Die on nan")
print(ms[-100:])
return
ms.append(r)
print('mean',np.mean(ms), 'scale', model.scale.w.value)
print('loss',loss(xs_[-100:], ys_[-100:]))
a = encode_ema(xs_[-valid_size:])
b = decode_ema(ys_[-valid_size:])
br = b[np.random.permutation(len(b))]
print('score',np.mean(np.sum(a*b,axis=(1)) - np.sum(a*br,axis=(1))),
np.mean(np.sum(a*b,axis=(1)) > np.sum(a*br,axis=(1))))
ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
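# Illustrative sketch (not part of the original file): the contrastive loss
# inside train_model above is a K-way softmax cross-entropy over the
# similarity matrix a @ b.T, where matching pairs sit on the diagonal. The
# numpy version below computes the same quantity on made-up unit vectors,
# without the learned temperature.
def _demo_contrastive_loss():
    rng = np.random.default_rng(0)
    a = rng.normal(size=(5, 8))
    a /= np.linalg.norm(a, axis=1, keepdims=True)
    b = a + 0.1 * rng.normal(size=(5, 8))
    b /= np.linalg.norm(b, axis=1, keepdims=True)
    logits = a @ b.T                                    # [5, 5] similarity matrix
    m = logits.max(axis=1, keepdims=True)               # stable log-softmax
    log_probs = logits - (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True)))
    print('contrastive loss:', -np.mean(np.diag(log_probs)))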
def load_challenge():
"""
Load the challenge dataset for attacking
"""
global xs, ys, encoded, original, ooriginal
print("SETUP: Loading matrixes")
# The encoded images
encoded = np.load("challenge-7.npy")
# And the original images
ooriginal = original = np.load("orig-7.npy")
print("Sizes", encoded.shape, ooriginal.shape)
# Again do that ugly resize thing to make the features be on the last dimension
# Look up above to see what's going on.
patch_size = 16
ss = original.shape[2]//patch_size
original = ooriginal.reshape((original.shape[0],1,ss,patch_size,ss,patch_size))
original = original.transpose((0,2,4,1,3,5))
original = original.reshape((original.shape[0], ss**2, patch_size**2))
def match_sub(args):
"""
Find the best way to undo the permutation between two images.
"""
vec1, vec2 = args
value = np.sum((vec1[None,:,:] - vec2[:,None,:])**2,axis=2)
row, col = scipy.optimize.linear_sum_assignment(value)
return col
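# Illustrative sketch (not part of the original file): match_sub above
# undoes a per-image patch shuffle by solving a min-cost assignment with
# scipy.optimize.linear_sum_assignment. The toy check below shuffles some
# made-up feature vectors and verifies the permutation is recovered.
def _demo_match_sub():
    rng = np.random.default_rng(0)
    vec1 = rng.normal(size=(6, 3))     # reference patch features
    perm = rng.permutation(6)
    vec2 = vec1[perm]                  # the same patches, shuffled
    col = match_sub((vec1, vec2))
    print(np.array_equal(col, perm))   # -> True: the shuffle is recovered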
def recover_local_permutation():
"""
Given a set of encoded images, return a new encoding without permutations
"""
global encoded, ys
p = mp.Pool(96)
print('recover local')
local_perm = p.map(match_sub, [(encoded[0], e) for e in encoded])
local_perm = np.array(local_perm)
encoded_perm = []
for i in range(len(encoded)):
encoded_perm.append(encoded[i][np.argsort(local_perm[i])])
encoded_perm = np.array(encoded_perm)
encoded = np.array(encoded_perm)
p.close()
def recover_better_local_permutation():
"""
Given a set of encoded images, return a new encoding, but better!
"""
global encoded, ys
# Now instead of pairing all images to image 0, we compute the mean l2 vector
# and then pair all images onto the mean vector. Slightly more noise resistant.
p = mp.Pool(96)
target = encoded.mean(0)
local_perm = p.map(match_sub, [(target, e) for e in encoded])
local_perm = np.array(local_perm)
# Probably we didn't change by much, generally <0.1%
print('improved changed by', np.mean(local_perm != np.arange(local_perm.shape[1])))
encoded_perm = []
for i in range(len(encoded)):
encoded_perm.append(encoded[i][np.argsort(local_perm[i])])
encoded = np.array(encoded_perm)
p.close()
def compute_patch_similarity():
"""
Compute the feature vectors for each patch using the trained neural network.
"""
global xs, ys, xs_image, ys_image
print("Computing features")
ys = features(encoded, encoded=True)
xs = features(original, encoded=False)
model = Model()
ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
ckpt.restore(model.vars())
xs_image = model.encode(xs)
ys_image = model.decode(ys)
assert xs.shape[0] == xs_image.shape[0]
print("Done")
def match(args, ret_col=False):
"""
Compute the similarity between image features and encoded features.
"""
vec1, vec2s = args
r = []
open("/tmp/start%d.%d"%(np.random.randint(10000),time.time()),"w").write("hi")
for vec2 in vec2s:
value = np.sum(vec1[None,:,:] * vec2[:,None,:],axis=2)
row, col = scipy.optimize.linear_sum_assignment(-value)
r.append(value[row,col].mean())
return r
def recover_global_matching_first():
"""
Recover the global matching of original to encoded images by doing
an all-pairs matching problem
"""
global global_matching, ys_image, encoded
matrix = []
p = mp.Pool(96)
xs_image_ = np.array(xs_image)
ys_image_ = np.array(ys_image)
matrix = p.map(match, [(x, ys_image_) for x in xs_image_])
matrix = np.array(matrix).reshape((xs_image.shape[0],
xs_image.shape[0]))
row, col = scipy.optimize.linear_sum_assignment(-np.array(matrix))
global_matching = np.argsort(col)
print('glob',list(global_matching))
p.close()
def recover_global_permutation():
"""
Find the way that the encoded images are permuted off of the original images
"""
global global_permutation
print("Glob match", global_matching)
overall = []
for i,j in enumerate(global_matching):
overall.append(np.sum(xs_image[j][None,:,:] * ys_image[i][:,None,:],axis=2))
overall = np.mean(overall, 0)
row, col = scipy.optimize.linear_sum_assignment(-overall)
try:
print("Changed frac:", np.mean(global_permutation!=np.argsort(col)))
except NameError:
# On the first call, global_permutation does not exist yet.
pass
global_permutation = np.argsort(col)
def recover_global_matching_second():
"""
Match each encoded image with its original encoded image,
but better by relying on the global permutation.
"""
global global_matching_second, global_matching
ys_fix = []
for i in range(ys_image.shape[0]):
ys_fix.append(ys_image[i][global_permutation])
ys_fix = np.array(ys_fix)
print(xs_image.shape)
sims = []
for i in range(0,len(xs_image),10):
tmp = np.mean(xs_image[None,:,:,:] * ys_fix[i:i+10][:,None,:,:],axis=(2,3))
sims.extend(tmp)
sims = np.array(sims)
print(sims.shape)
row, col = scipy.optimize.linear_sum_assignment(-sims)
print('arg',sims.argmax(1))
print("Same matching frac", np.mean(col == global_matching) )
print(col)
global_matching = col
def extract_by_training(resume):
"""
Final recovery process by extracting the neural network
"""
global inverse
device = torch.device('cuda:1')
if not resume:
inverse = PrivateEncoder(Xray(True)).cuda(device)
# More adam to train.
optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001)
this_xs = ooriginal[global_matching]
this_ys = encoded[:,global_permutation,:]
for i in range(2000):
idx = np.random.randint(0, len(this_xs), 32)  # random_integers is deprecated
xbatch = torch.tensor(this_xs[idx]).cuda(device)
ybatch = torch.tensor(this_ys[idx]).cuda(device)
optimizer.zero_grad()
guess_output = inverse(xbatch)
# L1 loss because we don't want to be sensitive to outliers
error = torch.mean(torch.abs(guess_output-ybatch))
error.backward()
optimizer.step()
print(error)
def test_extract():
"""
Now we can recover the matching much better by computing the estimated
encodings for each original image.
"""
global err, global_matching, guessed_encoded, smatrix
device = torch.device('cuda:1')
print(ooriginal.shape, encoded.shape)
out = []
for i in range(0,len(ooriginal),32):
print(i)
out.extend(inverse(torch.tensor(ooriginal[i:i+32]).cuda(device)).cpu().detach().numpy())
guessed_encoded = np.array(out)
# Now we have to compare each encoded image with every other original image.
# Do this fast with a jitted, blocked pairwise squared-distance computation.
out = guessed_encoded.reshape((len(encoded), -1))
real = encoded[:,global_permutation,:].reshape((len(encoded), -1))
@jax.jit
def foo(x, y):
return jn.square(x[:,None] - y[None,:]).sum(2)
smatrix = np.zeros((len(out), len(out)))
B = 500
for i in range(0,len(out),B):
print(i)
for j in range(0,len(out),B):
smatrix[i:i+B, j:j+B] = foo(out[i:i+B], real[j:j+B])
# And the final time you'll have to look at a min-weight matching, I promise.
row, col = scipy.optimize.linear_sum_assignment(np.array(smatrix))
r = np.array(smatrix)
print(list(row)[::100])
print("Differences", np.mean(np.argsort(col) != global_matching))
global_matching = np.argsort(col)
def perf(steps=[]):
if len(steps) == 0:
steps.append(time.time())
else:
print("Last Time Elapsed:", time.time()-steps[-1], ' Total Time Elapsed:', time.time()-steps[0])
steps.append(time.time())
time.sleep(1)
if __name__ == "__main__":
if True:
perf()
setup('challenge')
perf()
gen_train_data()
perf()
get_train_features()
perf()
train_model()
perf()
if True:
load_challenge()
perf()
recover_local_permutation()
perf()
recover_better_local_permutation()
perf()
compute_patch_similarity()
perf()
recover_global_matching_first()
perf()
for _ in range(3):
recover_global_permutation()
perf()
recover_global_matching_second()
perf()
for i in range(3):
recover_global_permutation()
perf()
extract_by_training(i > 0)
perf()
test_extract()
perf()
print(perf())