forked from 626_privacy/tensorflow_privacy
Internal
PiperOrigin-RevId: 428549678
This commit is contained in:
parent
c8bba41059
commit
8012d5b9c9
25 changed files with 0 additions and 3053 deletions
|
@ -1,11 +0,0 @@
|
|||
# Auditing Private Machine Learning
|
||||
Code for "Auditing Differentially Private Machine Learning: How Private is Private SGD?": https://arxiv.org/abs/2006.07709. This implementation is simple but not easily parallelizable. For a parallelizable version which is harder to run, see https://github.com/jagielski/auditing-dpsgd.
|
||||
|
||||
## Usage
|
||||
This attack relies on the AuditAttack class found in audit.py. The class allows one to generate poisoning, run trials to compute membership scores for the poisoning, and then use the resulting membership scores to compute a lower bound on epsilon.
|
||||
|
||||
## Examples
|
||||
Two examples are provided, mean_audit.py and fmnist_audit.py. fmnist_audit.py attacks the FashionMNIST dataset. It allows the user to specify between standard backdoor attacks and clipping-aware attacks, and also allows the user to specify between multiple poisoning attack sizes, model types, and whether to load saved model weights to start training from. mean_audit.py audits a model which computes the mean of a dataset. This provides an example of user-provided poisoning samples, rather than those autogenerated from our attacks.py library.
|
||||
|
||||
## Requirements
|
||||
Requires scikit-learn=0.24.1, statsmodels=0.12.2, tensorflow=1.14.0
|
|
@ -1,115 +0,0 @@
|
|||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Poisoning attack library for auditing."""
|
||||
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
def make_clip_aware(train_x, train_y, l2_norm=10):
|
||||
"""
|
||||
train_x: clean training features - must be shape (n_samples, n_features)
|
||||
train_y: clean training labels - must be shape (n_samples, )
|
||||
|
||||
Returns x, y1, y2
|
||||
x: poisoning sample
|
||||
y1: first corresponding y value
|
||||
y2: second corresponding y value
|
||||
"""
|
||||
x_shape = list(train_x.shape[1:])
|
||||
to_image = lambda x: x.reshape([-1] + x_shape) # reshapes to standard image shape
|
||||
flatten = lambda x: x.reshape((x.shape[0], -1)) # flattens all pixels - allows PCA
|
||||
|
||||
# make sure to_image an flatten are inverse functions
|
||||
assert np.allclose(to_image(flatten(train_x)), train_x)
|
||||
|
||||
flat_x = flatten(train_x)
|
||||
pca = PCA(flat_x.shape[1])
|
||||
pca.fit(flat_x)
|
||||
|
||||
new_x = l2_norm*pca.components_[-1]
|
||||
|
||||
lr = LogisticRegression(max_iter=1000)
|
||||
lr.fit(flat_x, np.argmax(train_y, axis=1))
|
||||
|
||||
num_classes = train_y.shape[1]
|
||||
lr_probs = lr.predict_proba(new_x[None, :])
|
||||
min_y = np.argmin(lr_probs)
|
||||
second_y = np.argmin(lr_probs + np.eye(num_classes)[min_y])
|
||||
|
||||
oh_min_y = np.eye(num_classes)[min_y]
|
||||
oh_second_y = np.eye(num_classes)[second_y]
|
||||
|
||||
return to_image(new_x), oh_min_y, oh_second_y
|
||||
|
||||
def make_backdoor(train_x, train_y):
|
||||
"""
|
||||
Makes a backdoored dataset, following Gu et al. https://arxiv.org/abs/1708.06733
|
||||
|
||||
train_x: clean training features - must be shape (n_samples, n_features)
|
||||
train_y: clean training labels - must be shape (n_samples, )
|
||||
|
||||
Returns x, y1, y2
|
||||
x: poisoning sample
|
||||
y1: first corresponding y value
|
||||
y2: second corresponding y value
|
||||
"""
|
||||
|
||||
sample_ind = np.random.choice(train_x.shape[0], 1)
|
||||
pois_x = np.copy(train_x[sample_ind, :])
|
||||
pois_x[0] = 1 # set corner feature to 1
|
||||
second_y = train_y[sample_ind]
|
||||
|
||||
num_classes = train_y.shape[1]
|
||||
min_y = np.eye(num_classes)[second_y.argmax(1) + 1]
|
||||
|
||||
return pois_x, min_y, second_y
|
||||
|
||||
|
||||
def make_many_poisoned_datasets(train_x, train_y, pois_sizes, attack="clip_aware", l2_norm=10):
|
||||
"""
|
||||
Makes a dict containing many poisoned datasets. make_pois is fairly slow:
|
||||
this avoids making multiple calls
|
||||
|
||||
train_x: clean training features - shape (n_samples, n_features)
|
||||
train_y: clean training labels - shape (n_samples, )
|
||||
pois_sizes: list of poisoning sizes
|
||||
l2_norm: l2 norm of the poisoned data
|
||||
|
||||
Returns dict: all_poisons
|
||||
all_poisons[poison_size] is a pair of poisoned datasets
|
||||
"""
|
||||
if attack == "clip_aware":
|
||||
pois_sample_x, y, second_y = make_clip_aware(train_x, train_y, l2_norm)
|
||||
elif attack == "backdoor":
|
||||
pois_sample_x, y, second_y = make_backdoor(train_x, train_y)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
all_poisons = {"pois": (pois_sample_x, y)}
|
||||
|
||||
for pois_size in pois_sizes: # make_pois is slow - don't want it in a loop
|
||||
new_pois_x1, new_pois_y1 = train_x.copy(), train_y.copy()
|
||||
new_pois_x2, new_pois_y2 = train_x.copy(), train_y.copy()
|
||||
|
||||
new_pois_x1[-pois_size:] = pois_sample_x[None, :]
|
||||
new_pois_y1[-pois_size:] = y
|
||||
|
||||
new_pois_x2[-pois_size:] = pois_sample_x[None, :]
|
||||
new_pois_y2[-pois_size:] = second_y
|
||||
|
||||
dataset1, dataset2 = (new_pois_x1, new_pois_y1), (new_pois_x2, new_pois_y2)
|
||||
all_poisons[pois_size] = dataset1, dataset2
|
||||
|
||||
return all_poisons
|
|
@ -1,119 +0,0 @@
|
|||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Class for running auditing procedure."""
|
||||
|
||||
import numpy as np
|
||||
from statsmodels.stats import proportion
|
||||
|
||||
import attacks
|
||||
|
||||
def compute_results(poison_scores, unpois_scores, pois_ct,
|
||||
alpha=0.05, threshold=None):
|
||||
"""
|
||||
Searches over thresholds for the best epsilon lower bound and accuracy.
|
||||
poison_scores: list of scores from poisoned models
|
||||
unpois_scores: list of scores from unpoisoned models
|
||||
pois_ct: number of poison points
|
||||
alpha: confidence parameter
|
||||
threshold: if None, search over all thresholds, else use given threshold
|
||||
"""
|
||||
if threshold is None: # search for best threshold
|
||||
all_thresholds = np.unique(poison_scores + unpois_scores)
|
||||
else:
|
||||
all_thresholds = [threshold]
|
||||
|
||||
poison_arr = np.array(poison_scores)
|
||||
unpois_arr = np.array(unpois_scores)
|
||||
|
||||
best_threshold, best_epsilon, best_acc = None, 0, 0
|
||||
for thresh in all_thresholds:
|
||||
epsilon, acc = compute_epsilon_and_acc(poison_arr, unpois_arr, thresh,
|
||||
alpha, pois_ct)
|
||||
if epsilon > best_epsilon:
|
||||
best_epsilon, best_threshold = epsilon, thresh
|
||||
best_acc = max(best_acc, acc)
|
||||
return best_threshold, best_epsilon, best_acc
|
||||
|
||||
|
||||
def compute_epsilon_and_acc(poison_arr, unpois_arr, threshold, alpha, pois_ct):
|
||||
"""For a given threshold, compute epsilon and accuracy."""
|
||||
poison_ct = (poison_arr > threshold).sum()
|
||||
unpois_ct = (unpois_arr > threshold).sum()
|
||||
|
||||
# clopper_pearson uses alpha/2 budget on upper and lower
|
||||
# so total budget will be 2*alpha/2 = alpha
|
||||
p1, _ = proportion.proportion_confint(poison_ct, poison_arr.size,
|
||||
alpha, method='beta')
|
||||
_, p0 = proportion.proportion_confint(unpois_ct, unpois_arr.size,
|
||||
alpha, method='beta')
|
||||
|
||||
if (p1 <= 1e-5) or (p0 >= 1 - 1e-5): # divide by zero issues
|
||||
return 0, 0
|
||||
|
||||
if (p0 + p1) > 1: # see Appendix A
|
||||
p0, p1 = (1-p1), (1-p0)
|
||||
|
||||
epsilon = np.log(p1/p0)/pois_ct
|
||||
acc = (p1 + (1-p0))/2 # this is not necessarily the best accuracy
|
||||
|
||||
return epsilon, acc
|
||||
|
||||
|
||||
class AuditAttack(object):
|
||||
"""Audit attack class. Generates poisoning, then runs auditing algorithm."""
|
||||
def __init__(self, train_x, train_y, train_function):
|
||||
"""
|
||||
train_x: training features
|
||||
train_y: training labels
|
||||
name: identifier for the attack
|
||||
train_function: function returning membership score
|
||||
"""
|
||||
self.train_x, self.train_y = train_x, train_y
|
||||
self.train_function = train_function
|
||||
self.poisoning = None
|
||||
|
||||
def make_poisoning(self, pois_ct, attack_type, l2_norm=10):
|
||||
"""Get poisoning data."""
|
||||
return attacks.make_many_poisoned_datasets(self.train_x, self.train_y, [pois_ct],
|
||||
attack=attack_type, l2_norm=l2_norm)
|
||||
|
||||
def run_experiments(self, num_trials):
|
||||
"""Runs all training experiments."""
|
||||
(pois_x1, pois_y1), (pois_x2, pois_y2) = self.poisoning['data']
|
||||
sample_x, sample_y = self.poisoning['pois']
|
||||
|
||||
poison_scores = []
|
||||
unpois_scores = []
|
||||
|
||||
for i in range(num_trials):
|
||||
poison_tuple = (pois_x1, pois_y1, sample_x, sample_y, i)
|
||||
unpois_tuple = (pois_x2, pois_y2, sample_x, sample_y, num_trials + i)
|
||||
poison_scores.append(self.train_function(poison_tuple))
|
||||
unpois_scores.append(self.train_function(unpois_tuple))
|
||||
|
||||
return poison_scores, unpois_scores
|
||||
|
||||
def run(self, pois_ct, attack_type, num_trials, alpha=0.05,
|
||||
threshold=None, l2_norm=10):
|
||||
"""Complete auditing algorithm. Generates poisoning if necessary."""
|
||||
if self.poisoning is None:
|
||||
self.poisoning = self.make_poisoning(pois_ct, attack_type, l2_norm)
|
||||
self.poisoning['data'] = self.poisoning[pois_ct]
|
||||
|
||||
poison_scores, unpois_scores = self.run_experiments(num_trials)
|
||||
|
||||
results = compute_results(poison_scores, unpois_scores, pois_ct,
|
||||
alpha=alpha, threshold=threshold)
|
||||
return results
|
|
@ -1,91 +0,0 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Tests for audit.py."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import audit
|
||||
|
||||
def dummy_train_and_score_function(dataset):
|
||||
del dataset
|
||||
return 0
|
||||
|
||||
def get_auditor():
|
||||
poisoning = {}
|
||||
datasets = (np.zeros((5, 2)), np.zeros(5)), (np.zeros((5, 2)), np.zeros(5))
|
||||
poisoning["data"] = datasets
|
||||
poisoning["pois"] = (datasets[0][0][0], datasets[0][1][0])
|
||||
auditor = audit.AuditAttack(datasets[0][0], datasets[0][1],
|
||||
dummy_train_and_score_function)
|
||||
auditor.poisoning = poisoning
|
||||
|
||||
return auditor
|
||||
|
||||
|
||||
class AuditParameterizedTest(parameterized.TestCase):
|
||||
"""Class to test parameterized audit.py functions."""
|
||||
@parameterized.named_parameters(
|
||||
('Test0', np.ones(500), np.zeros(500), 0.5, 0.01, 1,
|
||||
(4.541915810224092, 0.9894593118113243)),
|
||||
('Test1', np.ones(500), np.zeros(500), 0.5, 0.01, 2,
|
||||
(2.27095790511, 0.9894593118113243)),
|
||||
('Test2', np.ones(500), np.ones(500), 0.5, 0.01, 1,
|
||||
(0, 0))
|
||||
)
|
||||
|
||||
def test_compute_epsilon_and_acc(self, poison_scores, unpois_scores,
|
||||
threshold, pois_ct, alpha, expected_res):
|
||||
expected_eps, expected_acc = expected_res
|
||||
computed_res = audit.compute_epsilon_and_acc(poison_scores, unpois_scores,
|
||||
threshold, pois_ct, alpha)
|
||||
computed_eps, computed_acc = computed_res
|
||||
self.assertAlmostEqual(computed_eps, expected_eps)
|
||||
self.assertAlmostEqual(computed_acc, expected_acc)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('Test0', [1]*500, [0]*250 + [.5]*250, 1, 0.01, .5,
|
||||
(.5, 4.541915810224092, 0.9894593118113243)),
|
||||
('Test1', [1]*500, [0]*250 + [.5]*250, 1, 0.01, None,
|
||||
(.5, 4.541915810224092, 0.9894593118113243)),
|
||||
('Test2', [1]*500, [0]*500, 2, 0.01, .5,
|
||||
(.5, 2.27095790511, 0.9894593118113243)),
|
||||
)
|
||||
|
||||
def test_compute_results(self, poison_scores, unpois_scores, pois_ct,
|
||||
alpha, threshold, expected_res):
|
||||
expected_thresh, expected_eps, expected_acc = expected_res
|
||||
computed_res = audit.compute_results(poison_scores, unpois_scores,
|
||||
pois_ct, alpha, threshold)
|
||||
computed_thresh, computed_eps, computed_acc = computed_res
|
||||
self.assertAlmostEqual(computed_thresh, expected_thresh)
|
||||
self.assertAlmostEqual(computed_eps, expected_eps)
|
||||
self.assertAlmostEqual(computed_acc, expected_acc)
|
||||
|
||||
|
||||
class AuditAttackTest(absltest.TestCase):
|
||||
"""Nonparameterized audit.py test class."""
|
||||
def test_run_experiments(self):
|
||||
auditor = get_auditor()
|
||||
pois, unpois = auditor.run_experiments(100)
|
||||
expected = [0]*100
|
||||
self.assertListEqual(pois, expected)
|
||||
self.assertListEqual(unpois, expected)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
|
@ -1,176 +0,0 @@
|
|||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Run auditing on the FashionMNIST dataset."""
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import audit
|
||||
|
||||
#### FLAGS
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
|
||||
flags.DEFINE_float('noise_multiplier', 1.1,
|
||||
'Ratio of the standard deviation to the clipping norm')
|
||||
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||
flags.DEFINE_integer('batch_size', 250, 'Batch size')
|
||||
flags.DEFINE_integer('epochs', 24, 'Number of epochs')
|
||||
flags.DEFINE_integer(
|
||||
'microbatches', 250, 'Number of microbatches '
|
||||
'(must evenly divide batch_size)')
|
||||
flags.DEFINE_string('model', 'lr', 'model to use, pick between lr and nn')
|
||||
flags.DEFINE_string('attack_type', "clip_aware", 'clip_aware or backdoor')
|
||||
flags.DEFINE_integer('pois_ct', 1, 'Number of poisoning points')
|
||||
flags.DEFINE_integer('num_trials', 100, 'Number of trials for auditing')
|
||||
flags.DEFINE_float('attack_l2_norm', 10, 'Size of poisoning data')
|
||||
flags.DEFINE_float('alpha', 0.05, '1-confidence')
|
||||
flags.DEFINE_boolean('load_weights', False,
|
||||
'if True, use weights saved in init_weights.h5')
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def compute_epsilon(train_size):
|
||||
"""Computes epsilon value for given hyperparameters."""
|
||||
if FLAGS.noise_multiplier == 0.0:
|
||||
return float('inf')
|
||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||
sampling_probability = FLAGS.batch_size / train_size
|
||||
steps = FLAGS.epochs * train_size / FLAGS.batch_size
|
||||
rdp = compute_rdp(q=sampling_probability,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
steps=steps,
|
||||
orders=orders)
|
||||
# Delta is set to approximate 1 / (number of training points).
|
||||
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
|
||||
|
||||
def build_model(x, y):
|
||||
"""Build a keras model."""
|
||||
input_shape = x.shape[1:]
|
||||
num_classes = y.shape[1]
|
||||
l2 = 0
|
||||
if FLAGS.model == 'lr':
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Flatten(input_shape=input_shape),
|
||||
tf.keras.layers.Dense(num_classes, kernel_initializer='glorot_normal',
|
||||
kernel_regularizer=tf.keras.regularizers.l2(l2))
|
||||
])
|
||||
elif FLAGS.model == 'nn':
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Flatten(input_shape=input_shape),
|
||||
tf.keras.layers.Dense(32, activation='relu',
|
||||
kernel_initializer='glorot_normal',
|
||||
kernel_regularizer=tf.keras.regularizers.l2(l2)),
|
||||
tf.keras.layers.Dense(num_classes, kernel_initializer='glorot_normal',
|
||||
kernel_regularizer=tf.keras.regularizers.l2(l2))
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return model
|
||||
|
||||
|
||||
def train_model(model, train_x, train_y, save_weights=False):
|
||||
"""Train the model on given data."""
|
||||
optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
|
||||
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
num_microbatches=FLAGS.microbatches,
|
||||
learning_rate=FLAGS.learning_rate)
|
||||
|
||||
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.losses.Reduction.NONE)
|
||||
|
||||
# Compile model with Keras
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||
|
||||
if save_weights:
|
||||
wts = model.get_weights()
|
||||
np.save('save_model', wts)
|
||||
model.set_weights(wts)
|
||||
return model
|
||||
|
||||
if FLAGS.load_weights: # load preset weights
|
||||
wts = np.load('save_model.npy', allow_pickle=True).tolist()
|
||||
model.set_weights(wts)
|
||||
|
||||
# Train model with Keras
|
||||
model.fit(train_x, train_y,
|
||||
epochs=FLAGS.epochs,
|
||||
validation_data=(train_x, train_y),
|
||||
batch_size=FLAGS.batch_size,
|
||||
verbose=0)
|
||||
return model
|
||||
|
||||
|
||||
def membership_test(model, pois_x, pois_y):
|
||||
"""Membership inference - detect poisoning."""
|
||||
probs = model.predict(np.concatenate([pois_x, np.zeros_like(pois_x)]))
|
||||
return np.multiply(probs[0, :] - probs[1, :], pois_y).sum()
|
||||
|
||||
|
||||
def train_and_score(dataset):
|
||||
"""Complete training run with membership inference score."""
|
||||
x, y, pois_x, pois_y, i = dataset
|
||||
np.random.seed(i)
|
||||
tf.set_random_seed(i)
|
||||
tf.reset_default_graph()
|
||||
model = build_model(x, y)
|
||||
model = train_model(model, x, y)
|
||||
return membership_test(model, pois_x, pois_y)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
del unused_argv
|
||||
# Load training and test data.
|
||||
np.random.seed(0)
|
||||
|
||||
(train_x, train_y), _ = tf.keras.datasets.fashion_mnist.load_data()
|
||||
train_inds = np.where(train_y < 2)[0]
|
||||
|
||||
train_x = -.5 + train_x[train_inds] / 255.
|
||||
train_y = np.eye(2)[train_y[train_inds]]
|
||||
|
||||
# subsample dataset
|
||||
ss_inds = np.random.choice(train_x.shape[0], train_x.shape[0]//2, replace=False)
|
||||
train_x = train_x[ss_inds]
|
||||
train_y = train_y[ss_inds]
|
||||
|
||||
init_model = build_model(train_x, train_y)
|
||||
_ = train_model(init_model, train_x, train_y, save_weights=True)
|
||||
|
||||
auditor = audit.AuditAttack(train_x, train_y, train_and_score)
|
||||
|
||||
thresh, _, _ = auditor.run(FLAGS.pois_ct, FLAGS.attack_type, FLAGS.num_trials,
|
||||
alpha=FLAGS.alpha, threshold=None,
|
||||
l2_norm=FLAGS.attack_l2_norm)
|
||||
|
||||
_, eps, acc = auditor.run(FLAGS.pois_ct, FLAGS.attack_type, FLAGS.num_trials,
|
||||
alpha=FLAGS.alpha, threshold=thresh,
|
||||
l2_norm=FLAGS.attack_l2_norm)
|
||||
|
||||
epsilon_upper_bound = compute_epsilon(train_x.shape[0])
|
||||
|
||||
print("Analysis epsilon is {}.".format(epsilon_upper_bound))
|
||||
print("At threshold={}, epsilon={}.".format(thresh, eps))
|
||||
print("The best accuracy at distinguishing poisoning is {}.".format(acc))
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
|
@ -1,152 +0,0 @@
|
|||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Auditing a model which computes the mean of a synthetic dataset.
|
||||
This gives an example for instrumenting the auditor to audit a user-given sample."""
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
|
||||
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import audit
|
||||
|
||||
#### FLAGS
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
|
||||
flags.DEFINE_float('noise_multiplier', 1.1,
|
||||
'Ratio of the standard deviation to the clipping norm')
|
||||
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||
flags.DEFINE_integer('batch_size', 250, 'Batch size')
|
||||
flags.DEFINE_integer('d', 250, 'Data dimension')
|
||||
flags.DEFINE_integer('epochs', 1, 'Number of epochs')
|
||||
flags.DEFINE_integer(
|
||||
'microbatches', 250, 'Number of microbatches '
|
||||
'(must evenly divide batch_size)')
|
||||
flags.DEFINE_string('attack_type', "clip_aware", 'clip_aware or backdoor')
|
||||
flags.DEFINE_integer('num_trials', 100, 'Number of trials for auditing')
|
||||
flags.DEFINE_float('attack_l2_norm', 10, 'Size of poisoning data')
|
||||
flags.DEFINE_float('alpha', 0.05, '1-confidence')
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def compute_epsilon(train_size):
|
||||
"""Computes epsilon value for given hyperparameters."""
|
||||
if FLAGS.noise_multiplier == 0.0:
|
||||
return float('inf')
|
||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||
sampling_probability = FLAGS.batch_size / train_size
|
||||
steps = FLAGS.epochs * train_size / FLAGS.batch_size
|
||||
rdp = compute_rdp(q=sampling_probability,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
steps=steps,
|
||||
orders=orders)
|
||||
# Delta is set to approximate 1 / (number of training points).
|
||||
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
|
||||
|
||||
def build_model(x, y):
|
||||
del x, y
|
||||
model = tf.keras.Sequential([tf.keras.layers.Dense(
|
||||
1, input_shape=(FLAGS.d,),
|
||||
use_bias=False, kernel_initializer=tf.keras.initializers.Zeros())])
|
||||
return model
|
||||
|
||||
|
||||
def train_model(model, train_x, train_y):
|
||||
"""Train the model on given data."""
|
||||
optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
|
||||
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
num_microbatches=FLAGS.microbatches,
|
||||
learning_rate=FLAGS.learning_rate)
|
||||
|
||||
# gradient of (.5-x.w)^2 is 2(.5-x.w)x
|
||||
loss = tf.keras.losses.MeanSquaredError(reduction=tf.losses.Reduction.NONE)
|
||||
|
||||
# Compile model with Keras
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=['mse'])
|
||||
|
||||
# Train model with Keras
|
||||
model.fit(train_x, train_y,
|
||||
epochs=FLAGS.epochs,
|
||||
validation_data=(train_x, train_y),
|
||||
batch_size=FLAGS.batch_size,
|
||||
verbose=0)
|
||||
return model
|
||||
|
||||
|
||||
def membership_test(model, pois_x, pois_y):
|
||||
"""Membership inference - detect poisoning."""
|
||||
del pois_y
|
||||
return model.predict(pois_x)
|
||||
|
||||
|
||||
def gen_data(n, d):
|
||||
"""Make binomial dataset."""
|
||||
x = np.random.normal(size=(n, d))
|
||||
y = np.ones(shape=(n,))/2.
|
||||
return x, y
|
||||
|
||||
|
||||
def train_and_score(dataset):
|
||||
"""Complete training run with membership inference score."""
|
||||
x, y, pois_x, pois_y, i = dataset
|
||||
np.random.seed(i)
|
||||
tf.set_random_seed(i)
|
||||
model = build_model(x, y)
|
||||
model = train_model(model, x, y)
|
||||
return membership_test(model, pois_x, pois_y)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
del unused_argv
|
||||
# Load training and test data.
|
||||
np.random.seed(0)
|
||||
|
||||
x, y = gen_data(1 + FLAGS.batch_size, FLAGS.d)
|
||||
|
||||
auditor = audit.AuditAttack(x, y, train_and_score)
|
||||
|
||||
# we will instrument the auditor to simply backdoor the last feature
|
||||
pois_x1, pois_x2 = x[:-1].copy(), x[:-1].copy()
|
||||
pois_x1[-1] = x[-1]
|
||||
pois_y = y[:-1]
|
||||
target_x = x[-1][None, :]
|
||||
assert np.unique(np.nonzero(pois_x1 - pois_x2)[0]).size == 1
|
||||
|
||||
pois_data = (pois_x1, pois_y), (pois_x2, pois_y), (target_x, y[-1])
|
||||
poisoning = {}
|
||||
poisoning["data"] = (pois_data[0], pois_data[1])
|
||||
poisoning["pois"] = pois_data[2]
|
||||
auditor.poisoning = poisoning
|
||||
|
||||
thresh, _, _ = auditor.run(1, None, FLAGS.num_trials, alpha=FLAGS.alpha)
|
||||
|
||||
_, eps, acc = auditor.run(1, None, FLAGS.num_trials, alpha=FLAGS.alpha,
|
||||
threshold=thresh)
|
||||
|
||||
epsilon_upper_bound = compute_epsilon(FLAGS.batch_size)
|
||||
|
||||
print("Analysis epsilon is {}.".format(epsilon_upper_bound))
|
||||
print("At threshold={}, epsilon={}.".format(thresh, eps))
|
||||
print("The best accuracy at distinguishing poisoning is {}.".format(acc))
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
|
@ -1,66 +0,0 @@
|
|||
Implementation of our reconstruction attack on InstaHide.
|
||||
|
||||
Is Private Learning Possible with Instance Encoding?
|
||||
Nicholas Carlini, Samuel Deng, Sanjam Garg, Somesh Jha, Saeed Mahloujifar, Mohammad Mahmoody, Shuang Song, Abhradeep Thakurta, Florian Tramer
|
||||
https://arxiv.org/abs/2011.05315
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
InstaHide is a recent privacy-preserving machine learning framework.
|
||||
It takes a (sensitive) dataset and generates encoded images that are privacy-preserving.
|
||||
Our attack breaks InstaHide and shows it does not offer meaningful privacy.
|
||||
Given the encoded dataset, we can recover a near-identical copy of the original images.
|
||||
|
||||
This repository implements the attack described in our paper. It consists of a number of
|
||||
steps that shoul be run sequentially. It assumes access to pre-trained neural network
|
||||
classifiers that should be downloaded following the steps below.
|
||||
|
||||
|
||||
### Requirements
|
||||
|
||||
* Python, version ≥ 3.5
|
||||
* jax
|
||||
* jaxlib
|
||||
* objax (https://github.com/google/objax)
|
||||
* PIL
|
||||
* sklearn
|
||||
|
||||
|
||||
### Running the attack
|
||||
|
||||
To reproduce our results and run the attack, each of the files should be run in turn.
|
||||
|
||||
0. Download the necessary dependency files:
|
||||
- (encryption.npy)[https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0] and (labels.npy)[https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0] from the (InstaHide Challenge)[https://github.com/Hazelsuko07/InstaHide_Challenge]
|
||||
- The (saved models)[https://drive.google.com/file/d/1YfKzGRfnnzKfUKpLjIRXRto8iD4FdwGw/view?usp=sharing] used to run the attack
|
||||
- Set up all the requirements as above
|
||||
|
||||
1. Run `step_1_create_graph.py`. Produce the similarity graph to pair together encoded images that share an original image.
|
||||
|
||||
2. Run `step_2_color_graph.py`. Color the graph to find 50 dense cliques.
|
||||
|
||||
3. Run `step_3_second_graph.py`. Create a new bipartite similarity graph.
|
||||
|
||||
4. Run `step_4_final_graph.py`. Solve the matching problem to assign encoded images to original images.
|
||||
|
||||
5. Run `step_5_reconstruct.py`. Reconstruct the original images.
|
||||
|
||||
6. Run `step_6_adjust_color.py`. Adjust the color curves to match.
|
||||
|
||||
7. Run `step_7_visualize.py`. Show the final resulting images.
|
||||
|
||||
## Citation
|
||||
|
||||
You can cite this attack at
|
||||
|
||||
```
|
||||
@inproceedings{carlini2021private,
|
||||
title={Is Private Learning Possible with Instance Encoding?},
|
||||
author={Carlini, Nicholas and Deng, Samuel and Garg, Sanjam and Jha, Somesh and Mahloujifar, Saeed and Mahmoody, Mohammad and Thakurta, Abhradeep and Tram{\`e}r, Florian},
|
||||
booktitle={2021 IEEE Symposium on Security and Privacy (SP)},
|
||||
pages={410--427},
|
||||
year={2021},
|
||||
organization={IEEE}
|
||||
}
|
||||
```
|
|
@ -1,77 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""
|
||||
Create the similarity graph given the encoded images by running the similarity
|
||||
neural network over all pairs of images.
|
||||
"""
|
||||
|
||||
import objax
|
||||
import numpy as np
|
||||
import jax.numpy as jn
|
||||
import functools
|
||||
import os
|
||||
import random
|
||||
|
||||
from objax.zoo import wide_resnet
|
||||
|
||||
def setup():
|
||||
global model
|
||||
class DoesUseSame(objax.Module):
|
||||
def __init__(self):
|
||||
fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
|
||||
self.model = fn(6,2)
|
||||
|
||||
model_vars = self.model.vars()
|
||||
self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
|
||||
|
||||
|
||||
def predict_op(x,y):
|
||||
# The model takes the two images and checks if they correspond
|
||||
# to the same original image.
|
||||
xx = jn.concatenate([jn.abs(x),
|
||||
jn.abs(y)],
|
||||
axis=1)
|
||||
return self.model(xx, training=False)
|
||||
|
||||
self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
|
||||
self.predict_fast = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
|
||||
|
||||
model = DoesUseSame()
|
||||
checkpoint = objax.io.Checkpoint("models/step1/", keep_ckpts=5, makedir=True)
|
||||
start_epoch, last_ckpt = checkpoint.restore(model.vars())
|
||||
|
||||
|
||||
def doall():
|
||||
global graph
|
||||
n = np.load("data/encryption.npy")
|
||||
n = np.transpose(n, (0,3,1,2))
|
||||
|
||||
# Compute the similarity between each encoded image and all others
|
||||
# This is n^2 work but should run fairly quickly, especially given
|
||||
# more than one GPU. Otherwise about an hour or so.
|
||||
graph = []
|
||||
with model.vars().replicate():
|
||||
for i in range(5000):
|
||||
print(i)
|
||||
v = model.predict_fast(np.tile(n[i:i+1], (5000,1,1,1)), n)
|
||||
graph.append(np.array(v[:,0]-v[:,1]))
|
||||
graph = np.array(graph)
|
||||
np.save("data/graph.npy", graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup()
|
||||
doall()
|
|
@ -1,95 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import random
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
def score(subset):
|
||||
sub = graph[subset]
|
||||
sub = sub[:,subset]
|
||||
return np.sum(sub)
|
||||
|
||||
def run(v, return_scores=False):
|
||||
if isinstance(v, int):
|
||||
v = [v]
|
||||
scores = []
|
||||
for _ in range(100):
|
||||
keep = graph[v,:]
|
||||
next_value = np.sum(keep,axis=0)
|
||||
to_add = next_value.argsort()
|
||||
to_add = [x for x in to_add if x not in v]
|
||||
if _ < 1:
|
||||
v.append(to_add[random.randint(0,10)])
|
||||
else:
|
||||
v.append(to_add[0])
|
||||
if return_scores:
|
||||
scores.append(score(v)/len(keep))
|
||||
if return_scores:
|
||||
return v, scores
|
||||
else:
|
||||
return v
|
||||
|
||||
def make_many_clusters():
|
||||
# Compute clusters of 100 examples that probably correspond to some original image
|
||||
p = mp.Pool(mp.cpu_count())
|
||||
s = p.map(run, range(2000))
|
||||
return s
|
||||
|
||||
|
||||
def downselect_clusters(s):
|
||||
# Right now we have a lot of clusters, but they probably overlap. Let's remove that.
|
||||
# We want to find disjoint clusters, so we'll greedily add them until we have
|
||||
# 100 distjoint clusters.
|
||||
|
||||
ss = [set(x) for x in s]
|
||||
|
||||
keep = []
|
||||
keep_set = []
|
||||
for iteration in range(2):
|
||||
for this_set in s:
|
||||
# MAGIC NUMBERS...!
|
||||
# We want clusters of size 50 because it works
|
||||
# Except on iteration 2 where we'll settle for 25 if we haven't
|
||||
# found clusters with 50 neighbors that work.
|
||||
cur = set(this_set[:50 - 25*iteration])
|
||||
intersections = np.array([len(cur & x) for x in ss])
|
||||
good = np.sum(intersections==50)>2
|
||||
# Good means that this cluster isn't a fluke and some other cluster
|
||||
# is like this one.
|
||||
if good or iteration == 1:
|
||||
print("N")
|
||||
# And also make sure we haven't found this cluster (or one like it).
|
||||
already_found = np.array([len(cur & x) for x in keep_set])
|
||||
if np.all(already_found<len(cur)/2):
|
||||
print("And is new")
|
||||
keep.append(this_set)
|
||||
keep_set.append(set(this_set))
|
||||
if len(keep) == 100:
|
||||
break
|
||||
print("Found", len(keep))
|
||||
if len(keep) == 100:
|
||||
break
|
||||
|
||||
# Keep should now have 100 items.
|
||||
# If it doesn't go and change the 2000 in make_many_clusters to a bigger number.
|
||||
return keep
|
||||
|
||||
if __name__ == "__main__":
|
||||
graph = np.load("data/graph.npy")
|
||||
np.save("data/many_clusters",make_many_clusters())
|
||||
np.save("data/100_clusters", downselect_clusters(np.load("data/many_clusters.npy")))
|
|
@ -1,114 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""
|
||||
Create the improved graph mapping each encoded image to an original image.
|
||||
"""
|
||||
|
||||
import objax
|
||||
import numpy as np
|
||||
import jax.numpy as jn
|
||||
import functools
|
||||
import os
|
||||
import random
|
||||
|
||||
from objax.zoo import wide_resnet
|
||||
|
||||
|
||||
def setup():
|
||||
global model
|
||||
class DoesUseSame(objax.Module):
|
||||
def __init__(self):
|
||||
fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
|
||||
self.model = fn(3*4,2)
|
||||
|
||||
model_vars = self.model.vars()
|
||||
self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
|
||||
|
||||
|
||||
def predict_op(x,y):
|
||||
# The model takes SEVERAL images and checks if they all correspond
|
||||
# to the same original image.
|
||||
# Guaranteed that the first N-1 all do, the test is if the last does.
|
||||
xx = jn.concatenate([jn.abs(x),
|
||||
jn.abs(y)],
|
||||
axis=1)
|
||||
return self.model(xx, training=False)
|
||||
|
||||
self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
|
||||
|
||||
model = DoesUseSame()
|
||||
checkpoint = objax.io.Checkpoint("models/step2/", keep_ckpts=5, makedir=True)
|
||||
start_epoch, last_ckpt = checkpoint.restore(model.vars())
|
||||
|
||||
def step2():
|
||||
global v, n, u, nextgraph
|
||||
|
||||
# Start out by loading the encoded images
|
||||
n = np.load("data/encryption.npy")
|
||||
n = np.transpose(n, (0,3,1,2))
|
||||
|
||||
# Then load the graph with 100 cluster-centers.
|
||||
keep = np.array(np.load("data/100_clusters.npy", allow_pickle=True))
|
||||
graph = np.load("data/graph.npy")
|
||||
|
||||
|
||||
# Now we're going to record the distance to each of the cluster centers
|
||||
# from every encoded image, so that we can do the matching.
|
||||
|
||||
# To do that, though, first we need to choose the cluster centers.
|
||||
# Start out by choosing the best cluster centers.
|
||||
|
||||
distances = []
|
||||
|
||||
for x in keep:
|
||||
this_set = x[:50]
|
||||
use_elts = graph[this_set]
|
||||
distances.append(np.sum(use_elts,axis=0))
|
||||
distances = np.array(distances)
|
||||
|
||||
ds = np.argsort(distances, axis=1)
|
||||
|
||||
# Now we record the "prototypes" of each cluster center.
|
||||
# We just need three, more might help a little bit but not much.
|
||||
# (And then do that ten times, so we can average out noise
|
||||
# with respect to which cluster centers we picked.)
|
||||
|
||||
prototypes = []
|
||||
for _ in range(10):
|
||||
ps = []
|
||||
# choose 3 random samples from each set
|
||||
for i in range(3):
|
||||
ps.append(n[ds[:,random.randint(0,20)]])
|
||||
prototypes.append(np.concatenate(ps,1))
|
||||
prototypes = np.concatenate(prototypes,0)
|
||||
|
||||
# Finally compute the distances from each node to each cluster center.
|
||||
nextgraph = []
|
||||
for i in range(5000):
|
||||
out = model.predict(prototypes, np.tile(n[i:i+1], (1000,1,1,1)))
|
||||
out = out.reshape((10, 100, 2))
|
||||
|
||||
v = np.sum(out,axis=0)
|
||||
v = v[:,0] - v[:,1]
|
||||
v = np.array(v)
|
||||
nextgraph.append(v)
|
||||
|
||||
np.save("data/nextgraph.npy", nextgraph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup()
|
||||
step2()
|
|
@ -1,51 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
|
||||
labels = np.load("data/label.npy")
|
||||
nextgraph = np.load("data/nextgraph.npy")
|
||||
|
||||
assigned = [[] for _ in range(5000)]
|
||||
lambdas = [[] for _ in range(5000)]
|
||||
for i in range(100):
|
||||
order = (np.argsort(nextgraph[:,i]))
|
||||
correct = (labels[order[:20]]>0).sum(axis=0).argmax()
|
||||
|
||||
# Let's create the final graph
|
||||
# Instead of doing a full bipartite matching, let's just greedily
|
||||
# choose the closest 80 candidates for each encoded image to pair
|
||||
# together can call it a day.
|
||||
# This is within a percent or two of doing that, and much easier.
|
||||
|
||||
# Also record the lambdas based on which image it coresponds to,
|
||||
# but if they share a label then just guess it's an even 50/50 split.
|
||||
|
||||
|
||||
for x in order[:80]:
|
||||
if labels[x][correct] > 0 and len(assigned[x]) < 2:
|
||||
assigned[x].append(i)
|
||||
if np.sum(labels[x]>0) == 1:
|
||||
# the same label was mixed in twice. punt.
|
||||
lambdas[x].append(labels[x][correct]/2)
|
||||
else:
|
||||
lambdas[x].append(labels[x][correct])
|
||||
|
||||
np.save("data/predicted_pairings_80.npy", assigned)
|
||||
np.save("data/predicted_lambdas_80.npy", lambdas)
|
|
@ -1,143 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""
|
||||
The final recovery happens here. Given the graph, reconstruct images.
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import jax.numpy as jn
|
||||
import jax
|
||||
import collections
|
||||
from PIL import Image
|
||||
|
||||
import jax.experimental.optimizers
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def toimg(x):
|
||||
#x = np.transpose(x,(1,2,0))
|
||||
print(x.shape)
|
||||
img = (x+1)*127.5
|
||||
return Image.fromarray(np.array(img,dtype=np.uint8))
|
||||
|
||||
|
||||
|
||||
def explained_variance(I, private_images, lambdas, encoded_images, public_to_private, return_mat=False):
|
||||
# private images: 100x32x32x3
|
||||
# encoded images: 5000x32x32x3
|
||||
|
||||
public_to_private = jax.nn.softmax(public_to_private,axis=-1)
|
||||
|
||||
# Compute the components from each of the images we know should map onto the same original image.
|
||||
component_1 = jn.dot(public_to_private[0], private_images.reshape((100,-1))).reshape((5000,32,32,3))
|
||||
component_2 = jn.dot(public_to_private[1], private_images.reshape((100,-1))).reshape((5000,32,32,3))
|
||||
|
||||
# Now combine them together to get the variance we can explain
|
||||
merged = component_1 * lambdas[:,0][:,None,None,None] + component_2 * lambdas[:,1][:,None,None,None]
|
||||
|
||||
# And now get the variance we can't explain.
|
||||
# This is the contribution of the public images.
|
||||
# We want this value to be small.
|
||||
|
||||
def keep_smallest_abs(xx1, xx2):
|
||||
t = 0
|
||||
which = (jn.abs(xx1+t) < jn.abs(xx2+t)) + 0.0
|
||||
return xx1 * which + xx2 * (1-which)
|
||||
|
||||
xx1 = jn.abs(encoded) - merged
|
||||
xx2 = -(jn.abs(encoded) + merged)
|
||||
|
||||
xx = keep_smallest_abs(xx1, xx2)
|
||||
unexplained_variance = xx
|
||||
|
||||
|
||||
if return_mat:
|
||||
return unexplained_variance, xx1, xx2
|
||||
|
||||
extra = (1-jn.abs(private_images)).mean()*.05
|
||||
|
||||
return extra + (unexplained_variance**2).mean()
|
||||
|
||||
def setup():
|
||||
global private, imagenet40, encoded, lambdas, using, real_using, pub_using
|
||||
|
||||
# Load all the things we've made.
|
||||
encoded = np.load("data/encryption.npy")
|
||||
labels = np.load("data/label.npy")
|
||||
using = np.load("data/predicted_pairings_80.npy", allow_pickle=True)
|
||||
lambdas = list(np.load("data/predicted_lambdas_80.npy", allow_pickle=True))
|
||||
for x in lambdas:
|
||||
while len(x) < 2:
|
||||
x.append(0)
|
||||
lambdas = np.array(lambdas)
|
||||
|
||||
# Construct the mapping
|
||||
public_to_private_new = np.zeros((2, 5000, 100))
|
||||
|
||||
cs = [0]*100
|
||||
for i,row in enumerate(using):
|
||||
for j,b in enumerate(row[:2]):
|
||||
public_to_private_new[j, i, b] = 1e9
|
||||
cs[b] += 1
|
||||
using = public_to_private_new
|
||||
|
||||
def loss(private, lams, I):
|
||||
return explained_variance(I, private, lams, jn.array(encoded), jn.array(using))
|
||||
|
||||
def make_loss():
|
||||
global vg
|
||||
vg = jax.jit(jax.value_and_grad(loss, argnums=(0,1)))
|
||||
|
||||
def run():
|
||||
priv = np.zeros((100,32,32,3))
|
||||
uusing = np.array(using)
|
||||
lams = np.array(lambdas)
|
||||
|
||||
# Use Adam, because thinking hard is overrated we have magic pixie dust.
|
||||
init_1, opt_update_1, get_params_1 = jax.experimental.optimizers.adam(.01)
|
||||
@jax.jit
|
||||
def update_1(i, opt_state, gs):
|
||||
return opt_update_1(i, gs, opt_state)
|
||||
opt_state_1 = init_1(priv)
|
||||
|
||||
# 1000 iterations of gradient descent is probably enough
|
||||
for i in range(1000):
|
||||
value, grad = vg(priv, lams, i)
|
||||
|
||||
if i%100 == 0:
|
||||
print(value)
|
||||
|
||||
var,_,_ = explained_variance(0, priv, jn.array(lambdas), jn.array(encoded), jn.array(using),
|
||||
return_mat=True)
|
||||
print('unexplained min/max', var.min(), var.max())
|
||||
opt_state_1 = update_1(i, opt_state_1, grad[0])
|
||||
priv = opt_state_1.packed_state[0][0]
|
||||
|
||||
priv -= np.min(priv, axis=(1,2,3), keepdims=True)
|
||||
priv /= np.max(priv, axis=(1,2,3), keepdims=True)
|
||||
priv *= 2
|
||||
priv -= 1
|
||||
|
||||
# Finally save the stored values
|
||||
np.save("data/private_raw.npy", priv)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup()
|
||||
make_loss()
|
||||
run()
|
|
@ -1,66 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Fix the color curves. Use a pre-trained "neural network" with <100 weights.
|
||||
Visually this helps a lot, even if it's not doing much of anything in pactice.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
||||
|
||||
import numpy as np
|
||||
import jax.numpy as jn
|
||||
|
||||
import objax
|
||||
|
||||
# Our extremely complicated neural network to re-color the images.
|
||||
# Takes one pixel at a time and fixes the color of that pixel.
|
||||
model = objax.nn.Sequential([objax.nn.Linear(3, 10),
|
||||
objax.functional.relu,
|
||||
objax.nn.Linear(10, 3)
|
||||
])
|
||||
|
||||
# These are the weights.
|
||||
weights = [[-0.09795442, -0.26434848, -0.24964345, -0.11450608, 0.6797288, -0.48435465,
|
||||
0.45307165, -0.31196147, -0.33266315, 0.20486055],
|
||||
[[-0.9056427, 0.02872663, -1.5114126, -0.41024876, -0.98195165, 0.1143966,
|
||||
0.6763464, -0.58654785, -1.797063, -0.2176538, ],
|
||||
[ 1.1941166, 0.15515928, 1.1691351, -0.7256186, 0.8046044, 1.3127686,
|
||||
-0.77297133, -1.1761239, 0.85841715, 0.95545965],
|
||||
[ 0.20092924, 0.57503146, 0.22809981, 1.5288007, -0.94781816, -0.68305916,
|
||||
-0.5245211, 1.4042739, -0.00527458, -1.1462274, ]],
|
||||
[0.15683544, 0.22086962, 0.33100453],
|
||||
[[ 7.7239674e-01, 4.0261227e-01, -9.6466336e-03],
|
||||
[-2.2159107e-01, 1.5123411e-01, 3.4485441e-01],
|
||||
[-1.7618114e+00, -7.1886492e-01, -4.6467595e-02],
|
||||
[ 6.9419539e-01, 6.2531930e-01, 7.2271496e-01],
|
||||
[-1.1913675e+00, -6.7755884e-01, -3.5114303e-01],
|
||||
[ 4.8022485e-01, 1.7145030e-01, 7.4849324e-04],
|
||||
[ 3.8332436e-02, -7.0614147e-01, -5.5127507e-01],
|
||||
[-1.0929481e+00, -1.0268525e+00, -7.0265180e-01],
|
||||
[ 1.4880739e+00, 7.1450096e-01, 2.9102692e-01],
|
||||
[ 7.2846663e-01, 7.1322352e-01, -1.7453632e-01]]]
|
||||
|
||||
for i,(k,v) in enumerate(model.vars().items()):
|
||||
v.assign(jn.array(weights[i]))
|
||||
|
||||
# Do all of the re-coloring
|
||||
predict = objax.Jit(lambda x: model(x, training=False),
|
||||
model.vars())
|
||||
|
||||
out = model(np.load("data/private_raw.npy"))
|
||||
np.save("data/private.npy", out)
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
"""
|
||||
Given the private images, draw them in a 100x100 grid for visualization.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
p = np.load("data/private.npy")
|
||||
|
||||
def toimg(x):
|
||||
print(x.shape)
|
||||
img = (x+1)*127.5
|
||||
img = np.clip(img, 0, 255)
|
||||
img = np.reshape(img, (10, 10, 32, 32, 3))
|
||||
img = np.concatenate(img, axis=2)
|
||||
img = np.concatenate(img, axis=0)
|
||||
img = Image.fromarray(np.array(img,dtype=np.uint8))
|
||||
return img
|
||||
|
||||
toimg(p).save("data/reconstructed.png")
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
## Membership Inference Attacks From First Principles
|
||||
|
||||
This directory contains code to reproduce our paper:
|
||||
|
||||
**"Membership Inference Attacks From First Principles"**
|
||||
https://arxiv.org/abs/2112.03570
|
||||
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
|
||||
|
||||
|
||||
### INSTALLING
|
||||
|
||||
You will need to install fairly standard dependencies
|
||||
|
||||
`pip install scipy, sklearn, numpy, matplotlib`
|
||||
|
||||
and also some machine learning framework to train models. We train our models
|
||||
with JAX + ObJAX so you will need to follow build instructions for that
|
||||
https://github.com/google/objax
|
||||
https://objax.readthedocs.io/en/latest/installation_setup.html
|
||||
|
||||
|
||||
### RUNNING THE CODE
|
||||
|
||||
#### 1. Train the models
|
||||
|
||||
The first step in our attack is to train shadow models. As a baseline
|
||||
that should give most of the gains in our attack, you should start by
|
||||
training 16 shadow models with the command
|
||||
|
||||
> bash scripts/train_demo.sh
|
||||
|
||||
or if you have multiple GPUs on your machine and want to train these models
|
||||
in parallel, then modify and run
|
||||
|
||||
> bash scripts/train_demo_multigpu.sh
|
||||
|
||||
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
|
||||
will output a bunch of files under the directory exp/cifar10 with structure:
|
||||
|
||||
```
|
||||
exp/cifar10/
|
||||
- experiment_N_of_16
|
||||
-- hparams.json
|
||||
-- keep.npy
|
||||
-- ckpt/
|
||||
--- 0000000100.npz
|
||||
-- tb/
|
||||
```
|
||||
|
||||
#### 2. Perform inference
|
||||
|
||||
Once the models are trained, now it's necessary to perform inference and save
|
||||
the output features for each training example for each model in the dataset.
|
||||
|
||||
> python3 inference.py --logdir=exp/cifar10/
|
||||
|
||||
This will add to the experiment directory a new set of files
|
||||
|
||||
```
|
||||
exp/cifar10/
|
||||
- experiment_N_of_16
|
||||
-- logits/
|
||||
--- 0000000100.npy
|
||||
```
|
||||
|
||||
where this new file has shape (50000, 10) and stores the model's
|
||||
output features for each example.
|
||||
|
||||
|
||||
#### 3. Compute membership inference scores
|
||||
|
||||
Finally we take the output features and generate our logit-scaled membership inference
|
||||
scores for each example for each model.
|
||||
|
||||
> python3 score.py exp/cifar10/
|
||||
|
||||
And this in turn generates a new directory
|
||||
|
||||
```
|
||||
exp/cifar10/
|
||||
- experiment_N_of_16
|
||||
-- scores/
|
||||
--- 0000000100.npy
|
||||
```
|
||||
|
||||
with shape (50000,) storing just our scores.
|
||||
|
||||
|
||||
### PLOTTING THE RESULTS
|
||||
|
||||
Finally we can generate pretty pictures, and run the plotting code
|
||||
|
||||
> python3 plot.py
|
||||
|
||||
which should give (something like) the following output
|
||||
|
||||
|
||||
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
||||
|
||||
```
|
||||
Attack Ours (online)
|
||||
AUC 0.6676, Accuracy 0.6077, TPR@0.1%FPR of 0.0169
|
||||
Attack Ours (online, fixed variance)
|
||||
AUC 0.6856, Accuracy 0.6137, TPR@0.1%FPR of 0.0593
|
||||
Attack Ours (offline)
|
||||
AUC 0.5488, Accuracy 0.5500, TPR@0.1%FPR of 0.0130
|
||||
Attack Ours (offline, fixed variance)
|
||||
AUC 0.5549, Accuracy 0.5537, TPR@0.1%FPR of 0.0299
|
||||
Attack Global threshold
|
||||
AUC 0.5921, Accuracy 0.6044, TPR@0.1%FPR of 0.0009
|
||||
```
|
||||
|
||||
where the global threshold attack is the baseline, and our online,
|
||||
online-with-fixed-variance, offline, and offline-with-fixed-variance
|
||||
attack variants are the four other curves. Note that because we only
|
||||
train a few models, the fixed variance variants perform best.
|
||||
|
||||
### Citation
|
||||
|
||||
You can cite this paper with
|
||||
|
||||
```
|
||||
@article{carlini2021membership,
|
||||
title={Membership Inference Attacks From First Principles},
|
||||
author={Carlini, Nicholas and Chien, Steve and Nasr, Milad and Song, Shuang and Terzis, Andreas and Tramer, Florian},
|
||||
journal={arXiv preprint arXiv:2112.03570},
|
||||
year={2021}
|
||||
}
|
||||
```
|
|
@ -1,95 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Optional, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
|
||||
features = tf.io.parse_single_example(serialized_example,
|
||||
features={'image': tf.io.FixedLenFeature([], tf.string),
|
||||
'label': tf.io.FixedLenFeature([], tf.int64)})
|
||||
image = tf.image.decode_image(features['image']).set_shape(image_shape)
|
||||
image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
|
||||
return dict(image=image, label=features['label'])
|
||||
|
||||
|
||||
class DataSet:
|
||||
"""Wrapper for tf.data.Dataset to permit extensions."""
|
||||
|
||||
def __init__(self, data: tf.data.Dataset,
|
||||
image_shape: Tuple[int, int, int],
|
||||
augment_fn: Optional[Callable] = None,
|
||||
parse_fn: Optional[Callable] = record_parse):
|
||||
self.data = data
|
||||
self.parse_fn = parse_fn
|
||||
self.augment_fn = augment_fn
|
||||
self.image_shape = image_shape
|
||||
|
||||
@classmethod
|
||||
def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
|
||||
return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
|
||||
augment_fn=augment_fn, parse_fn=None)
|
||||
|
||||
@classmethod
|
||||
def from_files(cls, filenames: List[str],
|
||||
image_shape: Tuple[int, int, int],
|
||||
augment_fn: Optional[Callable],
|
||||
parse_fn: Optional[Callable] = record_parse):
|
||||
filenames_in = filenames
|
||||
filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
|
||||
if not filenames:
|
||||
raise ValueError('Empty dataset, files not found:', filenames_in)
|
||||
return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)
|
||||
|
||||
@classmethod
|
||||
def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
|
||||
augment_fn: Optional[Callable] = None):
|
||||
return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
|
||||
image_shape, augment_fn=augment_fn, parse_fn=None)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in self.__dict__:
|
||||
return self.__dict__[item]
|
||||
|
||||
def call_and_update(*args, **kwargs):
|
||||
v = getattr(self.__dict__['data'], item)(*args, **kwargs)
|
||||
if isinstance(v, tf.data.Dataset):
|
||||
return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
|
||||
return v
|
||||
|
||||
return call_and_update
|
||||
|
||||
def augment(self, para_augment: int = 4):
|
||||
if self.augment_fn:
|
||||
return self.map(self.augment_fn, para_augment)
|
||||
return self
|
||||
|
||||
def nchw(self):
|
||||
return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))
|
||||
|
||||
def one_hot(self, nclass: int):
|
||||
return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))
|
||||
|
||||
def parse(self, para_parse: int = 2):
|
||||
if not self.parse_fn:
|
||||
return self
|
||||
if self.image_shape:
|
||||
return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
|
||||
return self.map(self.parse_fn, para_parse)
|
Binary file not shown.
Before Width: | Height: | Size: 37 KiB |
|
@ -1,150 +0,0 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable
|
||||
import json
|
||||
|
||||
import re
|
||||
import jax
|
||||
import jax.numpy as jn
|
||||
import numpy as np
|
||||
import tensorflow as tf # For data augmentation.
|
||||
import tensorflow_datasets as tfds
|
||||
from absl import app, flags
|
||||
from tqdm import tqdm, trange
|
||||
import pickle
|
||||
from functools import partial
|
||||
|
||||
import objax
|
||||
from objax.jaxboard import SummaryWriter, Summary
|
||||
from objax.util import EasyDict
|
||||
from objax.zoo import convnet, wide_resnet
|
||||
|
||||
from dataset import DataSet
|
||||
|
||||
from train import MemModule, network
|
||||
|
||||
from collections import defaultdict
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(argv):
|
||||
"""
|
||||
Perform inference of the saved model in order to generate the
|
||||
output logits, using a particular set of augmentations.
|
||||
"""
|
||||
del argv
|
||||
tf.config.experimental.set_visible_devices([], "GPU")
|
||||
|
||||
def load(arch):
|
||||
return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10,
|
||||
mnist=FLAGS.dataset == 'mnist',
|
||||
arch=arch,
|
||||
lr=.1,
|
||||
batch=0,
|
||||
epochs=0,
|
||||
weight_decay=0)
|
||||
|
||||
def cache_load(arch):
|
||||
thing = []
|
||||
def fn():
|
||||
if len(thing) == 0:
|
||||
thing.append(load(arch))
|
||||
return thing[0]
|
||||
return fn
|
||||
|
||||
xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
|
||||
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
|
||||
|
||||
|
||||
def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
|
||||
|
||||
outs = []
|
||||
for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]:
|
||||
aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy()
|
||||
for dx in range(0, 2*shift+1, stride):
|
||||
for dy in range(0, 2*shift+1, stride):
|
||||
this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))
|
||||
|
||||
logits = model.model(this_x, training=True)
|
||||
outs.append(logits)
|
||||
|
||||
print(np.array(outs).shape)
|
||||
return np.array(outs).transpose((1, 0, 2))
|
||||
|
||||
N = 5000
|
||||
|
||||
def features(model, xbatch, ybatch):
|
||||
return get_loss(model, xbatch, ybatch,
|
||||
shift=0, reflect=True, stride=1)
|
||||
|
||||
for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
|
||||
if re.search(FLAGS.regex, path) is None:
|
||||
print("Skipping from regex")
|
||||
continue
|
||||
|
||||
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
|
||||
arch = hparams['arch']
|
||||
model = cache_load(arch)()
|
||||
|
||||
logdir = os.path.join(FLAGS.logdir, path)
|
||||
|
||||
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
|
||||
max_epoch, last_ckpt = checkpoint.restore(model.vars())
|
||||
if max_epoch == 0: continue
|
||||
|
||||
if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")):
|
||||
os.mkdir(os.path.join(FLAGS.logdir, path, "logits"))
|
||||
if FLAGS.from_epoch is not None:
|
||||
first = FLAGS.from_epoch
|
||||
else:
|
||||
first = max_epoch-1
|
||||
|
||||
for epoch in range(first,max_epoch+1):
|
||||
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
|
||||
# no checkpoint saved here
|
||||
continue
|
||||
|
||||
if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
|
||||
print("Skipping already generated file", epoch)
|
||||
continue
|
||||
|
||||
try:
|
||||
start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
|
||||
except:
|
||||
print("Fail to load", epoch)
|
||||
continue
|
||||
|
||||
stats = []
|
||||
|
||||
for i in range(0,len(xs_all),N):
|
||||
stats.extend(features(model, xs_all[i:i+N],
|
||||
ys_all[i:i+N]))
|
||||
# This will be shape N, augs, nclass
|
||||
|
||||
np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch),
|
||||
np.array(stats)[:,None,:,:])
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
|
||||
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
|
||||
flags.DEFINE_bool('random_labels', False, 'use random labels.')
|
||||
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
|
||||
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
|
||||
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
|
||||
flags.DEFINE_integer('modulus', 8, 'modulus.')
|
||||
app.run(main)
|
|
@ -1,224 +0,0 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import scipy.stats
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import auc, roc_curve
|
||||
import functools
|
||||
|
||||
# Look at me being proactive!
|
||||
import matplotlib
|
||||
matplotlib.rcParams['pdf.fonttype'] = 42
|
||||
matplotlib.rcParams['ps.fonttype'] = 42
|
||||
|
||||
|
||||
def sweep(score, x):
|
||||
"""
|
||||
Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
|
||||
"""
|
||||
fpr, tpr, _ = roc_curve(x, -score)
|
||||
acc = np.max(1-(fpr+(1-tpr))/2)
|
||||
return fpr, tpr, auc(fpr, tpr), acc
|
||||
|
||||
def load_data(p):
|
||||
"""
|
||||
Load our saved scores and then put them into a big matrix.
|
||||
"""
|
||||
global scores, keep
|
||||
scores = []
|
||||
keep = []
|
||||
|
||||
for root,ds,_ in os.walk(p):
|
||||
for f in ds:
|
||||
if not f.startswith("experiment"): continue
|
||||
if not os.path.exists(os.path.join(root,f,"scores")): continue
|
||||
last_epoch = sorted(os.listdir(os.path.join(root,f,"scores")))
|
||||
if len(last_epoch) == 0: continue
|
||||
scores.append(np.load(os.path.join(root,f,"scores",last_epoch[-1])))
|
||||
keep.append(np.load(os.path.join(root,f,"keep.npy")))
|
||||
|
||||
scores = np.array(scores)
|
||||
keep = np.array(keep)[:,:scores.shape[1]]
|
||||
|
||||
return scores, keep
|
||||
|
||||
def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
|
||||
fix_variance=False):
|
||||
"""
|
||||
Fit a two predictive models using keep and scores in order to predict
|
||||
if the examples in check_scores were training data or not, using the
|
||||
ground truth answer from check_keep.
|
||||
"""
|
||||
dat_in = []
|
||||
dat_out = []
|
||||
|
||||
for j in range(scores.shape[1]):
|
||||
dat_in.append(scores[keep[:,j],j,:])
|
||||
dat_out.append(scores[~keep[:,j],j,:])
|
||||
|
||||
in_size = min(min(map(len,dat_in)), in_size)
|
||||
out_size = min(min(map(len,dat_out)), out_size)
|
||||
|
||||
dat_in = np.array([x[:in_size] for x in dat_in])
|
||||
dat_out = np.array([x[:out_size] for x in dat_out])
|
||||
|
||||
mean_in = np.median(dat_in, 1)
|
||||
mean_out = np.median(dat_out, 1)
|
||||
|
||||
if fix_variance:
|
||||
std_in = np.std(dat_in)
|
||||
std_out = np.std(dat_in)
|
||||
else:
|
||||
std_in = np.std(dat_in, 1)
|
||||
std_out = np.std(dat_out, 1)
|
||||
|
||||
prediction = []
|
||||
answers = []
|
||||
for ans, sc in zip(check_keep, check_scores):
|
||||
pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in+1e-30)
|
||||
pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
|
||||
score = pr_in-pr_out
|
||||
|
||||
prediction.extend(score.mean(1))
|
||||
answers.extend(ans)
|
||||
|
||||
return prediction, answers
|
||||
|
||||
def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
|
||||
fix_variance=False):
|
||||
"""
|
||||
Fit a single predictive model using keep and scores in order to predict
|
||||
if the examples in check_scores were training data or not, using the
|
||||
ground truth answer from check_keep.
|
||||
"""
|
||||
dat_in = []
|
||||
dat_out = []
|
||||
|
||||
for j in range(scores.shape[1]):
|
||||
dat_in.append(scores[keep[:, j], j, :])
|
||||
dat_out.append(scores[~keep[:, j], j, :])
|
||||
|
||||
out_size = min(min(map(len,dat_out)), out_size)
|
||||
|
||||
dat_out = np.array([x[:out_size] for x in dat_out])
|
||||
|
||||
mean_out = np.median(dat_out, 1)
|
||||
|
||||
if fix_variance:
|
||||
std_out = np.std(dat_out)
|
||||
else:
|
||||
std_out = np.std(dat_out, 1)
|
||||
|
||||
prediction = []
|
||||
answers = []
|
||||
for ans, sc in zip(check_keep, check_scores):
|
||||
score = scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
|
||||
|
||||
prediction.extend(score.mean(1))
|
||||
answers.extend(ans)
|
||||
return prediction, answers
|
||||
|
||||
|
||||
def generate_global(keep, scores, check_keep, check_scores):
|
||||
"""
|
||||
Use a simple global threshold sweep to predict if the examples in
|
||||
check_scores were training data or not, using the ground truth answer from
|
||||
check_keep.
|
||||
"""
|
||||
prediction = []
|
||||
answers = []
|
||||
for ans, sc in zip(check_keep, check_scores):
|
||||
prediction.extend(-sc.mean(1))
|
||||
answers.extend(ans)
|
||||
|
||||
return prediction, answers
|
||||
|
||||
def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
|
||||
"""
|
||||
Generate the ROC curves by using ntest models as test models and the rest to train.
|
||||
"""
|
||||
|
||||
prediction, answers = fn(keep[:-ntest],
|
||||
scores[:-ntest],
|
||||
keep[-ntest:],
|
||||
scores[-ntest:])
|
||||
|
||||
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
|
||||
|
||||
low = tpr[np.where(fpr<.001)[0][-1]]
|
||||
|
||||
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
|
||||
|
||||
metric_text = ''
|
||||
if metric == 'auc':
|
||||
metric_text = 'auc=%.3f'%auc
|
||||
elif metric == 'acc':
|
||||
metric_text = 'acc=%.3f'%acc
|
||||
|
||||
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
|
||||
return (acc,auc)
|
||||
|
||||
|
||||
def fig_fpr_tpr():
|
||||
|
||||
plt.figure(figsize=(4,3))
|
||||
|
||||
do_plot(generate_ours,
|
||||
keep, scores, 1,
|
||||
"Ours (online)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(functools.partial(generate_ours, fix_variance=True),
|
||||
keep, scores, 1,
|
||||
"Ours (online, fixed variance)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(functools.partial(generate_ours_offline),
|
||||
keep, scores, 1,
|
||||
"Ours (offline)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(functools.partial(generate_ours_offline, fix_variance=True),
|
||||
keep, scores, 1,
|
||||
"Ours (offline, fixed variance)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(generate_global,
|
||||
keep, scores, 1,
|
||||
"Global threshold\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
plt.semilogx()
|
||||
plt.semilogy()
|
||||
plt.xlim(1e-5,1)
|
||||
plt.ylim(1e-5,1)
|
||||
plt.xlabel("False Positive Rate")
|
||||
plt.ylabel("True Positive Rate")
|
||||
plt.plot([0, 1], [0, 1], ls='--', color='gray')
|
||||
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
|
||||
plt.legend(fontsize=8)
|
||||
plt.savefig("/tmp/fprtpr.png")
|
||||
plt.show()
|
||||
|
||||
|
||||
load_data("exp/cifar10/")
|
||||
fig_fpr_tpr()
|
|
@ -1,66 +0,0 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
def load_one(base):
|
||||
"""
|
||||
This loads a logits and converts it to a scored prediction.
|
||||
"""
|
||||
root = os.path.join(logdir,base,'logits')
|
||||
if not os.path.exists(root): return None
|
||||
|
||||
if not os.path.exists(os.path.join(logdir,base,'scores')):
|
||||
os.mkdir(os.path.join(logdir,base,'scores'))
|
||||
|
||||
for f in os.listdir(root):
|
||||
try:
|
||||
opredictions = np.load(os.path.join(root,f))
|
||||
except:
|
||||
print("Fail")
|
||||
continue
|
||||
|
||||
## Be exceptionally careful.
|
||||
## Numerically stable everything, as described in the paper.
|
||||
predictions = opredictions - np.max(opredictions, axis=3, keepdims=True)
|
||||
predictions = np.array(np.exp(predictions), dtype=np.float64)
|
||||
predictions = predictions/np.sum(predictions,axis=3,keepdims=True)
|
||||
|
||||
COUNT = predictions.shape[0]
|
||||
# x num_examples x num_augmentations x logits
|
||||
y_true = predictions[np.arange(COUNT),:,:,labels[:COUNT]]
|
||||
print(y_true.shape)
|
||||
|
||||
print('mean acc',np.mean(predictions[:,0,0,:].argmax(1)==labels[:COUNT]))
|
||||
|
||||
predictions[np.arange(COUNT),:,:,labels[:COUNT]] = 0
|
||||
y_wrong = np.sum(predictions, axis=3)
|
||||
|
||||
logit = (np.log(y_true.mean((1))+1e-45) - np.log(y_wrong.mean((1))+1e-45))
|
||||
|
||||
np.save(os.path.join(logdir, base, 'scores', f), logit)
|
||||
|
||||
|
||||
def load_stats():
|
||||
with mp.Pool(8) as p:
|
||||
p.map(load_one, [x for x in os.listdir(logdir) if 'exp' in x])
|
||||
|
||||
|
||||
logdir = sys.argv[1]
|
||||
labels = np.load(os.path.join(logdir,"y_train.npy"))
|
||||
load_stats()
|
|
@ -1,16 +0,0 @@
|
|||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
|
|
@ -1,18 +0,0 @@
|
|||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||
wait;
|
|
@ -1,329 +0,0 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import shutil
|
||||
from typing import Callable
|
||||
import json
|
||||
|
||||
import jax
|
||||
import jax.numpy as jn
|
||||
import numpy as np
|
||||
import tensorflow as tf # For data augmentation.
|
||||
import tensorflow_datasets as tfds
|
||||
from absl import app, flags
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import objax
|
||||
from objax.jaxboard import SummaryWriter, Summary
|
||||
from objax.util import EasyDict
|
||||
from objax.zoo import convnet, wide_resnet, dnnet
|
||||
|
||||
from dataset import DataSet
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
def augment(x, shift: int, mirror=True):
|
||||
"""
|
||||
Augmentation function used in training the model.
|
||||
"""
|
||||
y = x['image']
|
||||
if mirror:
|
||||
y = tf.image.random_flip_left_right(y)
|
||||
y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT')
|
||||
y = tf.image.random_crop(y, tf.shape(x['image']))
|
||||
return dict(image=y, label=x['label'])
|
||||
|
||||
|
||||
class TrainLoop(objax.Module):
|
||||
"""
|
||||
Training loop for general machine learning models.
|
||||
Based on the training loop from the objax CIFAR10 example code.
|
||||
"""
|
||||
predict: Callable
|
||||
train_op: Callable
|
||||
|
||||
def __init__(self, nclass: int, **kwargs):
|
||||
self.nclass = nclass
|
||||
self.params = EasyDict(kwargs)
|
||||
|
||||
def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
|
||||
kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
|
||||
for k, v in kv.items():
|
||||
if jn.isnan(v):
|
||||
raise ValueError('NaN, try reducing learning rate', k)
|
||||
if summary is not None:
|
||||
summary.scalar(k, float(v))
|
||||
|
||||
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
|
||||
"""
|
||||
Completely standard training. Nothing interesting to see here.
|
||||
"""
|
||||
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
|
||||
start_epoch, last_ckpt = checkpoint.restore(self.vars())
|
||||
train_iter = iter(train)
|
||||
progress = np.zeros(jax.local_device_count(), 'f') # for multi-GPU
|
||||
|
||||
best_acc = 0
|
||||
best_acc_epoch = -1
|
||||
|
||||
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
|
||||
for epoch in range(start_epoch, num_train_epochs):
|
||||
# Train
|
||||
summary = Summary()
|
||||
loop = range(0, train_size, self.params.batch)
|
||||
for step in loop:
|
||||
progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
|
||||
self.train_step(summary, next(train_iter), progress)
|
||||
|
||||
# Eval
|
||||
accuracy, total = 0, 0
|
||||
if epoch%FLAGS.eval_steps == 0 and test is not None:
|
||||
for data in test:
|
||||
total += data['image'].shape[0]
|
||||
preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
|
||||
accuracy += (preds == data['label'].numpy()).sum()
|
||||
accuracy /= total
|
||||
summary.scalar('eval/accuracy', 100 * accuracy)
|
||||
tensorboard.write(summary, step=(epoch + 1) * train_size)
|
||||
print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
|
||||
summary['eval/accuracy']()))
|
||||
|
||||
if summary['eval/accuracy']() > best_acc:
|
||||
best_acc = summary['eval/accuracy']()
|
||||
best_acc_epoch = epoch
|
||||
elif patience is not None and epoch > best_acc_epoch + patience:
|
||||
print("early stopping!")
|
||||
checkpoint.save(self.vars(), epoch + 1)
|
||||
return
|
||||
|
||||
else:
|
||||
print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))
|
||||
|
||||
if epoch%save_steps == save_steps-1:
|
||||
checkpoint.save(self.vars(), epoch + 1)
|
||||
|
||||
|
||||
# We inherit from the training loop and define predict and train_op.
|
||||
class MemModule(TrainLoop):
|
||||
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
|
||||
"""
|
||||
Completely standard training. Nothing interesting to see here.
|
||||
"""
|
||||
super().__init__(nclass, **kwargs)
|
||||
self.model = model(1 if mnist else 3, nclass)
|
||||
self.opt = objax.optimizer.Momentum(self.model.vars())
|
||||
self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)
|
||||
|
||||
@objax.Function.with_vars(self.model.vars())
|
||||
def loss(x, label):
|
||||
logit = self.model(x, training=True)
|
||||
loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
|
||||
loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
|
||||
return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}
|
||||
|
||||
gv = objax.GradValues(loss, self.model.vars())
|
||||
self.gv = gv
|
||||
|
||||
@objax.Function.with_vars(self.vars())
|
||||
def train_op(progress, x, y):
|
||||
g, v = gv(x, y)
|
||||
lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
|
||||
lr = lr * jn.clip(progress*100,0,1)
|
||||
self.opt(lr, g)
|
||||
self.model_ema.update_ema()
|
||||
return {'monitors/lr': lr, **v[1]}
|
||||
|
||||
self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
|
||||
|
||||
self.train_op = objax.Jit(train_op)
|
||||
|
||||
|
||||
def network(arch: str):
|
||||
if arch == 'cnn32-3-max':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
|
||||
pooling=objax.functional.max_pool_2d)
|
||||
elif arch == 'cnn32-3-mean':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
|
||||
pooling=objax.functional.average_pool_2d)
|
||||
elif arch == 'cnn64-3-max':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
|
||||
pooling=objax.functional.max_pool_2d)
|
||||
elif arch == 'cnn64-3-mean':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
|
||||
pooling=objax.functional.average_pool_2d)
|
||||
elif arch == 'wrn28-1':
|
||||
return functools.partial(wide_resnet.WideResNet, depth=28, width=1)
|
||||
elif arch == 'wrn28-2':
|
||||
return functools.partial(wide_resnet.WideResNet, depth=28, width=2)
|
||||
elif arch == 'wrn28-10':
|
||||
return functools.partial(wide_resnet.WideResNet, depth=28, width=10)
|
||||
raise ValueError('Architecture not recognized', arch)
|
||||
|
||||
def get_data(seed):
|
||||
"""
|
||||
This is the function to generate subsets of the data for training models.
|
||||
|
||||
First, we get the training dataset either from the numpy cache
|
||||
or otherwise we load it from tensorflow datasets.
|
||||
|
||||
Then, we compute the subset. This works in one of two ways.
|
||||
|
||||
1. If we have a seed, then we just randomly choose examples based on
|
||||
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
|
||||
|
||||
2. Otherwise, if we have an experiment ID, then we do something fancier.
|
||||
If we run each experiment independently then even after a lot of trials
|
||||
there will still probably be some examples that were always included
|
||||
or always excluded. So instead, with experiment IDs, we guarantee that
|
||||
after FLAGS.num_experiments are done, each example is seen exactly half
|
||||
of the time in train, and half of the time not in train.
|
||||
|
||||
"""
|
||||
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
|
||||
|
||||
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
|
||||
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
|
||||
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
|
||||
else:
|
||||
print("First time, creating dataset")
|
||||
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
|
||||
inputs = data['train']['image']
|
||||
labels = data['train']['label']
|
||||
|
||||
inputs = (inputs/127.5)-1
|
||||
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
|
||||
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
|
||||
|
||||
nclass = np.max(labels)+1
|
||||
|
||||
np.random.seed(seed)
|
||||
if FLAGS.num_experiments is not None:
|
||||
np.random.seed(0)
|
||||
keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size))
|
||||
order = keep.argsort(0)
|
||||
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
|
||||
keep = np.array(keep[FLAGS.expid], dtype=bool)
|
||||
else:
|
||||
keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= FLAGS.pkeep
|
||||
|
||||
if FLAGS.only_subset is not None:
|
||||
keep[FLAGS.only_subset:] = 0
|
||||
|
||||
xs = inputs[keep]
|
||||
ys = labels[keep]
|
||||
|
||||
if FLAGS.augment == 'weak':
|
||||
aug = lambda x: augment(x, 4)
|
||||
elif FLAGS.augment == 'mirror':
|
||||
aug = lambda x: augment(x, 0)
|
||||
elif FLAGS.augment == 'none':
|
||||
aug = lambda x: augment(x, 0, mirror=False)
|
||||
else:
|
||||
raise
|
||||
|
||||
train = DataSet.from_arrays(xs, ys,
|
||||
augment_fn=aug)
|
||||
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
|
||||
train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch)
|
||||
train = train.nchw().one_hot(nclass).prefetch(16)
|
||||
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16)
|
||||
|
||||
return train, test, xs, ys, keep, nclass
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
tf.config.experimental.set_visible_devices([], "GPU")
|
||||
|
||||
seed = FLAGS.seed
|
||||
if seed is None:
|
||||
import time
|
||||
seed = np.random.randint(0, 1000000000)
|
||||
seed ^= int(time.time())
|
||||
|
||||
args = EasyDict(arch=FLAGS.arch,
|
||||
lr=FLAGS.lr,
|
||||
batch=FLAGS.batch,
|
||||
weight_decay=FLAGS.weight_decay,
|
||||
augment=FLAGS.augment,
|
||||
seed=seed)
|
||||
|
||||
|
||||
if FLAGS.tunename:
|
||||
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
|
||||
elif FLAGS.expid is not None:
|
||||
logdir = "experiment-%d_%d"%(FLAGS.expid,FLAGS.num_experiments)
|
||||
else:
|
||||
logdir = "experiment-"+str(seed)
|
||||
logdir = os.path.join(FLAGS.logdir, logdir)
|
||||
|
||||
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
|
||||
print(f"run {FLAGS.expid} already completed.")
|
||||
return
|
||||
else:
|
||||
if os.path.exists(logdir):
|
||||
print(f"deleting run {FLAGS.expid} that did not complete.")
|
||||
shutil.rmtree(logdir)
|
||||
|
||||
print(f"starting run {FLAGS.expid}.")
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
|
||||
train, test, xs, ys, keep, nclass = get_data(seed)
|
||||
|
||||
# Define the network and train_it
|
||||
tm = MemModule(network(FLAGS.arch), nclass=nclass,
|
||||
mnist=FLAGS.dataset == 'mnist',
|
||||
epochs=FLAGS.epochs,
|
||||
expid=FLAGS.expid,
|
||||
num_experiments=FLAGS.num_experiments,
|
||||
pkeep=FLAGS.pkeep,
|
||||
save_steps=FLAGS.save_steps,
|
||||
only_subset=FLAGS.only_subset,
|
||||
**args
|
||||
)
|
||||
|
||||
r = {}
|
||||
r.update(tm.params)
|
||||
|
||||
open(os.path.join(logdir,'hparams.json'),"w").write(json.dumps(tm.params))
|
||||
np.save(os.path.join(logdir,'keep.npy'), keep)
|
||||
|
||||
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
|
||||
save_steps=FLAGS.save_steps, patience=FLAGS.patience)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
|
||||
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
|
||||
flags.DEFINE_integer('batch', 256, 'Batch size')
|
||||
flags.DEFINE_integer('epochs', 501, 'Training duration in number of epochs.')
|
||||
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
|
||||
flags.DEFINE_integer('seed', None, 'Training seed.')
|
||||
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
|
||||
flags.DEFINE_integer('expid', None, 'Experiment ID')
|
||||
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
|
||||
flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
|
||||
flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.')
|
||||
flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.')
|
||||
flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
|
||||
flags.DEFINE_integer('abort_after_epoch', None, 'stop trainin early at an epoch')
|
||||
flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
|
||||
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
|
||||
flags.DEFINE_bool('tunename', False, 'Use tune name?')
|
||||
app.run(main)
|
|
@ -1,712 +0,0 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# This program solves the NeuraCrypt challenge to 100% accuracy.
|
||||
# Given a set of encoded images and original versions of those,
|
||||
# it shows how to match the original to the encoded.
|
||||
|
||||
import collections
|
||||
import hashlib
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import scipy.stats
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
|
||||
import jax
|
||||
import jax.numpy as jn
|
||||
import objax
|
||||
import scipy.optimize
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
|
||||
# Objax neural network that's going to embed patches to a
|
||||
# low dimensional space to guess if two patches correspond
|
||||
# to the same orginal image.
|
||||
class Model(objax.Module):
|
||||
def __init__(self):
|
||||
IN = 15
|
||||
H = 64
|
||||
self.encoder =objax.nn.Sequential([
|
||||
objax.nn.Linear(IN, H),
|
||||
objax.functional.leaky_relu,
|
||||
objax.nn.Linear(H, H),
|
||||
objax.functional.leaky_relu,
|
||||
objax.nn.Linear(H, 8)])
|
||||
self.decoder =objax.nn.Sequential([
|
||||
objax.nn.Linear(IN, H),
|
||||
objax.functional.leaky_relu,
|
||||
objax.nn.Linear(H, H),
|
||||
objax.functional.leaky_relu,
|
||||
objax.nn.Linear(H, 8)])
|
||||
self.scale = objax.nn.Linear(1, 1, use_bias=False)
|
||||
def encode(self, x):
|
||||
# Encode turns original images into feature space
|
||||
a = self.encoder(x)
|
||||
a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5
|
||||
return a
|
||||
def decode(self, x):
|
||||
# And decode turns encoded images into feature space
|
||||
a = self.decoder(x)
|
||||
a = a/jn.sum(a**2,axis=-1,keepdims=True)**.5
|
||||
return a
|
||||
|
||||
# Proxy dataset for analysis
|
||||
class ImageNet:
|
||||
num_chan = 3
|
||||
private_kernel_size = 16
|
||||
hidden_dim = 2048
|
||||
img_size = (256, 256)
|
||||
private_depth = 7
|
||||
def __init__(self, remove):
|
||||
self.remove_pixel_shuffle = remove
|
||||
|
||||
# Original dataset as used in the NeuraCrypt paper
|
||||
class Xray:
|
||||
num_chan = 1
|
||||
private_kernel_size = 16
|
||||
hidden_dim = 2048
|
||||
img_size = (256, 256)
|
||||
private_depth = 4
|
||||
def __init__(self, remove):
|
||||
self.remove_pixel_shuffle = remove
|
||||
|
||||
## The following class is taken directly from the NeuraCrypt codebase.
|
||||
## https://github.com/yala/NeuraCrypt
|
||||
## which is originally licensed under the MIT License
|
||||
class PrivateEncoder(nn.Module):
|
||||
def __init__(self, args, width_factor=1):
|
||||
super(PrivateEncoder, self).__init__()
|
||||
self.args = args
|
||||
input_dim = args.num_chan
|
||||
patch_size = args.private_kernel_size
|
||||
output_dim = args.hidden_dim
|
||||
num_patches = (args.img_size[0] // patch_size) **2
|
||||
self.noise_size = 1
|
||||
|
||||
args.input_dim = args.hidden_dim
|
||||
|
||||
|
||||
layers = [
|
||||
nn.Conv2d(input_dim, output_dim * width_factor, kernel_size=patch_size, dilation=1 ,stride=patch_size),
|
||||
nn.ReLU()
|
||||
]
|
||||
for _ in range(self.args.private_depth):
|
||||
layers.extend( [
|
||||
nn.Conv2d(output_dim * width_factor, output_dim * width_factor , kernel_size=1, dilation=1, stride=1),
|
||||
nn.BatchNorm2d(output_dim * width_factor, track_running_stats=False),
|
||||
nn.ReLU()
|
||||
])
|
||||
|
||||
|
||||
self.image_encoder = nn.Sequential(*layers)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, output_dim * width_factor))
|
||||
|
||||
self.mixer = nn.Sequential( *[
|
||||
nn.ReLU(),
|
||||
nn.Linear(output_dim * width_factor, output_dim)
|
||||
])
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
encoded = self.image_encoder(x)
|
||||
B, C, H,W = encoded.size()
|
||||
encoded = encoded.view([B, -1, H*W]).transpose(1,2)
|
||||
encoded += self.pos_embedding
|
||||
encoded = self.mixer(encoded)
|
||||
|
||||
## Shuffle indicies
|
||||
if not self.args.remove_pixel_shuffle:
|
||||
shuffled = torch.zeros_like(encoded)
|
||||
for i in range(B):
|
||||
idx = torch.randperm(H*W, device=encoded.device)
|
||||
for j, k in enumerate(idx):
|
||||
shuffled[i,j] = encoded[i,k]
|
||||
encoded = shuffled
|
||||
|
||||
return encoded
|
||||
## End copied code
|
||||
|
||||
def setup(ds):
|
||||
"""
|
||||
Load the datasets to use. Nothing interesting to see.
|
||||
"""
|
||||
global x_train, y_train
|
||||
if ds == 'imagenet':
|
||||
import torchvision
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize(256),
|
||||
torchvision.transforms.CenterCrop(256),
|
||||
torchvision.transforms.ToTensor()])
|
||||
imagenet_data = torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/',
|
||||
split='val',
|
||||
transform=transform)
|
||||
data_loader = torch.utils.data.DataLoader(imagenet_data,
|
||||
batch_size=100,
|
||||
shuffle=True,
|
||||
num_workers=8)
|
||||
r = []
|
||||
for x,_ in data_loader:
|
||||
if len(r) > 1000: break
|
||||
print(x.shape)
|
||||
r.extend(x.numpy())
|
||||
x_train = np.array(r)
|
||||
print(x_train.shape)
|
||||
elif ds == 'xray':
|
||||
import torchvision
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize(256),
|
||||
torchvision.transforms.CenterCrop(256),
|
||||
torchvision.transforms.ToTensor()])
|
||||
imagenet_data = torchvision.datasets.ImageFolder('CheXpert-v1.0/train',
|
||||
transform=transform)
|
||||
data_loader = torch.utils.data.DataLoader(imagenet_data,
|
||||
batch_size=100,
|
||||
shuffle=True,
|
||||
num_workers=8)
|
||||
r = []
|
||||
for x,_ in data_loader:
|
||||
if len(r) > 1000: break
|
||||
print(x.shape)
|
||||
r.extend(x.numpy())
|
||||
x_train = np.array(r)
|
||||
print(x_train.shape)
|
||||
elif ds == 'challenge':
|
||||
x_train = np.load("orig-7.npy")
|
||||
print(np.min(x_train), np.max(x_train), x_train.shape)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def gen_train_data():
|
||||
"""
|
||||
Generate aligned training data to train a patch similarity function.
|
||||
Given some original images, generate lots of encoded versions.
|
||||
"""
|
||||
global encoded_train, original_train
|
||||
|
||||
encoded_train = []
|
||||
original_train = []
|
||||
|
||||
args = Xray(True)
|
||||
|
||||
C = 100
|
||||
for i in range(30):
|
||||
print(i)
|
||||
torch.manual_seed(int(time.time()))
|
||||
e = PrivateEncoder(args).cuda()
|
||||
batch = np.random.randint(0, len(x_train), size=C)
|
||||
xin = x_train[batch]
|
||||
|
||||
r = []
|
||||
for i in range(0,C,32):
|
||||
r.extend(e(torch.tensor(xin[i:i+32]).cuda()).detach().cpu().numpy())
|
||||
r = np.array(r)
|
||||
|
||||
encoded_train.append(r)
|
||||
original_train.append(xin)
|
||||
|
||||
def features_(x, moments=15, encoded=False):
|
||||
"""
|
||||
Compute higher-order moments for patches in an image to use as
|
||||
features for the neural network.
|
||||
"""
|
||||
x = np.array(x, dtype=np.float32)
|
||||
dim = 2
|
||||
arr = np.array([np.mean(x, dim)] + [abs(scipy.stats.moment(x, moment=i, axis=dim))**(1/i) for i in range(1,moments)])
|
||||
|
||||
return arr.transpose((1,2,0))
|
||||
|
||||
|
||||
def features(x, encoded):
|
||||
"""
|
||||
Given the original images or the encoded images, generate the
|
||||
features to use for the patch similarity function.
|
||||
"""
|
||||
print('start shape',x.shape)
|
||||
if len(x.shape) == 3:
|
||||
x = x - np.mean(x,axis=0,keepdims=True)
|
||||
else:
|
||||
# count x 100 x 256 x 768
|
||||
print(x[0].shape)
|
||||
x = x - np.mean(x,axis=1,keepdims=True)
|
||||
# remove per-neural-network dimension
|
||||
x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
|
||||
p = mp.Pool(96)
|
||||
B = len(x)//96
|
||||
print(1)
|
||||
bs = [x[i:i+B] for i in range(0,len(x),B)]
|
||||
print(2)
|
||||
r = p.map(features_, bs)
|
||||
#r = features_(bs[0][:100])
|
||||
print(3)
|
||||
p.close()
|
||||
#r = np.array(r)
|
||||
#print('finish',r.shape)
|
||||
return np.concatenate(r, axis=0)
|
||||
|
||||
|
||||
|
||||
def get_train_features():
|
||||
"""
|
||||
Create features for the entire datasets.
|
||||
"""
|
||||
global xs_train, ys_train
|
||||
print(x_train.shape)
|
||||
original_train_ = np.array(original_train)
|
||||
encoded_train_ = np.array(encoded_train)
|
||||
|
||||
print("Computing features")
|
||||
ys_train = features(encoded_train_, True)
|
||||
|
||||
patch_size = 16
|
||||
ss = original_train_.shape[3]//patch_size
|
||||
# Okay so this is an ugly transpose block.
|
||||
# We are going from [outer_batch, batch_size, channels, width, height
|
||||
# to [outer_batch, batch_size, channels, width/patch_size, patch_size, height/patch_size, patch_size]
|
||||
# Then we reshape this and flatten so that we end up with
|
||||
# [other_batch, batch_size, width/patch_size, height_patch_size, patch_size**2*channels]
|
||||
# So that now we can run features on the last dimension
|
||||
original_train_ = original_train_.reshape((original_train_.shape[0],
|
||||
original_train_.shape[1],
|
||||
original_train_.shape[2],
|
||||
ss,patch_size,ss,patch_size)).transpose((0,1,3,5,2,4,6)).reshape((original_train_.shape[0], original_train_.shape[1], ss**2, patch_size**2))
|
||||
|
||||
|
||||
xs_train = features(original_train_, False)
|
||||
|
||||
print(xs_train.shape, ys_train.shape)
|
||||
|
||||
|
||||
def train_model():
|
||||
"""
|
||||
Train the patch similarity function
|
||||
"""
|
||||
global ema, model
|
||||
|
||||
model = Model()
|
||||
def loss(x, y):
|
||||
"""
|
||||
K-way contrastive loss as in SimCLR et al.
|
||||
The idea is that we should embed x and y so that they are similar
|
||||
to each other, and dis-similar from others. To do this we have a
|
||||
softmx loss over one dimension to make the values large on the diagonal
|
||||
and small off-diagonal.
|
||||
"""
|
||||
a = model.encode(x)
|
||||
b = model.decode(y)
|
||||
|
||||
mat = a@b.T
|
||||
return objax.functional.loss.cross_entropy_logits_sparse(
|
||||
logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat,
|
||||
labels=np.arange(a.shape[0])).mean()
|
||||
|
||||
ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)
|
||||
gv = objax.GradValues(loss, model.vars())
|
||||
|
||||
encode_ema = ema.replace_vars(lambda x: model.encode(x))
|
||||
decode_ema = ema.replace_vars(lambda y: model.decode(y))
|
||||
|
||||
def train_op(x, y):
|
||||
"""
|
||||
No one was ever fired for using Adam with 1e-4.
|
||||
"""
|
||||
g, v = gv(x, y)
|
||||
opt(1e-4, g)
|
||||
ema()
|
||||
return v
|
||||
|
||||
opt = objax.optimizer.Adam(model.vars())
|
||||
train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars())
|
||||
|
||||
ys_ = ys_train
|
||||
|
||||
print(ys_.shape)
|
||||
|
||||
xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
|
||||
ys_ = ys_.reshape((-1, ys_train.shape[-1]))
|
||||
|
||||
# The model scale trick here is taken from CLIP.
|
||||
# Let the model decide how confident to make its own predictions.
|
||||
model.scale.w.assign(jn.zeros((1,1)))
|
||||
|
||||
valid_size = 1000
|
||||
|
||||
print(xs_train.shape)
|
||||
# SimCLR likes big batches
|
||||
B = 4096
|
||||
for it in range(80):
|
||||
print()
|
||||
ms = []
|
||||
for i in range(1000):
|
||||
# First batch is smaller, to make training more stable
|
||||
bs = [B//64, B][it>0]
|
||||
batch = np.random.randint(0, len(xs_)-valid_size, size=bs)
|
||||
r = train_op(xs_[batch], ys_[batch])
|
||||
|
||||
# This shouldn't happen, but if it does, better to bort early
|
||||
if np.isnan(r):
|
||||
print("Die on nan")
|
||||
print(ms[-100:])
|
||||
return
|
||||
ms.append(r)
|
||||
|
||||
print('mean',np.mean(ms), 'scale', model.scale.w.value)
|
||||
print('loss',loss(xs_[-100:], ys_[-100:]))
|
||||
|
||||
a = encode_ema(xs_[-valid_size:])
|
||||
b = decode_ema(ys_[-valid_size:])
|
||||
|
||||
br = b[np.random.permutation(len(b))]
|
||||
|
||||
print('score',np.mean(np.sum(a*b,axis=(1)) - np.sum(a*br,axis=(1))),
|
||||
np.mean(np.sum(a*b,axis=(1)) > np.sum(a*br,axis=(1))))
|
||||
ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
|
||||
ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
|
||||
|
||||
|
||||
|
||||
def load_challenge():
|
||||
"""
|
||||
Load the challenge datast for attacking
|
||||
"""
|
||||
global xs, ys, encoded, original, ooriginal
|
||||
print("SETUP: Loading matrixes")
|
||||
# The encoded images
|
||||
encoded = np.load("challenge-7.npy")
|
||||
# And the original images
|
||||
ooriginal = original = np.load("orig-7.npy")
|
||||
|
||||
print("Sizes", encoded.shape, ooriginal.shape)
|
||||
|
||||
# Again do that ugly resize thing to make the features be on the last dimension
|
||||
# Look up above to see what's going on.
|
||||
patch_size = 16
|
||||
ss = original.shape[2]//patch_size
|
||||
original = ooriginal.reshape((original.shape[0],1,ss,patch_size,ss,patch_size))
|
||||
original = original.transpose((0,2,4,1,3,5))
|
||||
original = original.reshape((original.shape[0], ss**2, patch_size**2))
|
||||
|
||||
|
||||
def match_sub(args):
|
||||
"""
|
||||
Find the best way to undo the permutation between two images.
|
||||
"""
|
||||
vec1, vec2 = args
|
||||
value = np.sum((vec1[None,:,:] - vec2[:,None,:])**2,axis=2)
|
||||
row, col = scipy.optimize.linear_sum_assignment(value)
|
||||
return col
|
||||
|
||||
|
||||
def recover_local_permutation():
|
||||
"""
|
||||
Given a set of encoded images, return a new encoding without permutations
|
||||
"""
|
||||
global encoded, ys
|
||||
|
||||
p = mp.Pool(96)
|
||||
print('recover local')
|
||||
local_perm = p.map(match_sub, [(encoded[0], e) for e in encoded])
|
||||
local_perm = np.array(local_perm)
|
||||
|
||||
encoded_perm = []
|
||||
|
||||
for i in range(len(encoded)):
|
||||
encoded_perm.append(encoded[i][np.argsort(local_perm[i])])
|
||||
|
||||
encoded_perm = np.array(encoded_perm)
|
||||
|
||||
encoded = np.array(encoded_perm)
|
||||
|
||||
p.close()
|
||||
|
||||
|
||||
def recover_better_local_permutation():
|
||||
"""
|
||||
Given a set of encoded images, return a new encoding, but better!
|
||||
"""
|
||||
global encoded, ys
|
||||
|
||||
# Now instead of pairing all images to image 0, we compute the mean l2 vector
|
||||
# and then pair all images onto the mean vector. Slightly more noise resistant.
|
||||
p = mp.Pool(96)
|
||||
target = encoded.mean(0)
|
||||
local_perm = p.map(match_sub, [(target, e) for e in encoded])
|
||||
local_perm = np.array(local_perm)
|
||||
|
||||
# Probably we didn't change by much, generally <0.1%
|
||||
print('improved changed by', np.mean(local_perm != np.arange(local_perm.shape[1])))
|
||||
|
||||
encoded_perm = []
|
||||
|
||||
for i in range(len(encoded)):
|
||||
encoded_perm.append(encoded[i][np.argsort(local_perm[i])])
|
||||
|
||||
encoded = np.array(encoded_perm)
|
||||
|
||||
p.close()
|
||||
|
||||
|
||||
def compute_patch_similarity():
|
||||
"""
|
||||
Compute the feature vectors for each patch using the trained neural network.
|
||||
"""
|
||||
global xs, ys, xs_image, ys_image
|
||||
|
||||
print("Computing features")
|
||||
ys = features(encoded, encoded=True)
|
||||
xs = features(original, encoded=False)
|
||||
|
||||
model = Model()
|
||||
ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
|
||||
ckpt.restore(model.vars())
|
||||
|
||||
xs_image = model.encode(xs)
|
||||
ys_image = model.decode(ys)
|
||||
assert xs.shape[0] == xs_image.shape[0]
|
||||
print("Done")
|
||||
|
||||
|
||||
def match(args, ret_col=False):
|
||||
"""
|
||||
Compute the similarity between image features and encoded features.
|
||||
"""
|
||||
vec1, vec2s = args
|
||||
r = []
|
||||
open("/tmp/start%d.%d"%(np.random.randint(10000),time.time()),"w").write("hi")
|
||||
for vec2 in vec2s:
|
||||
value = np.sum(vec1[None,:,:] * vec2[:,None,:],axis=2)
|
||||
|
||||
row, col = scipy.optimize.linear_sum_assignment(-value)
|
||||
r.append(value[row,col].mean())
|
||||
return r
|
||||
|
||||
|
||||
|
||||
def recover_global_matching_first():
|
||||
"""
|
||||
Recover the global matching of original to encoded images by doing
|
||||
an all-pairs matching problem
|
||||
"""
|
||||
global global_matching, ys_image, encoded
|
||||
|
||||
matrix = []
|
||||
p = mp.Pool(96)
|
||||
xs_image_ = np.array(xs_image)
|
||||
ys_image_ = np.array(ys_image)
|
||||
|
||||
matrix = p.map(match, [(x, ys_image_) for x in xs_image_])
|
||||
matrix = np.array(matrix).reshape((xs_image.shape[0],
|
||||
xs_image.shape[0]))
|
||||
|
||||
|
||||
row, col = scipy.optimize.linear_sum_assignment(-np.array(matrix))
|
||||
global_matching = np.argsort(col)
|
||||
print('glob',list(global_matching))
|
||||
|
||||
p.close()
|
||||
|
||||
|
||||
|
||||
def recover_global_permutation():
|
||||
"""
|
||||
Find the way that the encoded images are permuted off of the original images
|
||||
"""
|
||||
global global_permutation
|
||||
|
||||
print("Glob match", global_matching)
|
||||
overall = []
|
||||
for i,j in enumerate(global_matching):
|
||||
overall.append(np.sum(xs_image[j][None,:,:] * ys_image[i][:,None,:],axis=2))
|
||||
|
||||
overall = np.mean(overall, 0)
|
||||
|
||||
row, col = scipy.optimize.linear_sum_assignment(-overall)
|
||||
|
||||
try:
|
||||
print("Changed frac:", np.mean(global_permutation!=np.argsort(col)))
|
||||
except:
|
||||
pass
|
||||
|
||||
global_permutation = np.argsort(col)
|
||||
|
||||
|
||||
def recover_global_matching_second():
|
||||
"""
|
||||
Match each encoded image with its original encoded image,
|
||||
but better by relying on the global permutation.
|
||||
"""
|
||||
global global_matching_second, global_matching
|
||||
|
||||
ys_fix = []
|
||||
for i in range(ys_image.shape[0]):
|
||||
ys_fix.append(ys_image[i][global_permutation])
|
||||
ys_fix = np.array(ys_fix)
|
||||
|
||||
|
||||
print(xs_image.shape)
|
||||
|
||||
sims = []
|
||||
for i in range(0,len(xs_image),10):
|
||||
tmp = np.mean(xs_image[None,:,:,:] * ys_fix[i:i+10][:,None,:,:],axis=(2,3))
|
||||
sims.extend(tmp)
|
||||
sims = np.array(sims)
|
||||
print(sims.shape)
|
||||
|
||||
|
||||
row, col = scipy.optimize.linear_sum_assignment(-sims)
|
||||
|
||||
print('arg',sims.argmax(1))
|
||||
|
||||
print("Same matching frac", np.mean(col == global_matching) )
|
||||
print(col)
|
||||
global_matching = col
|
||||
|
||||
|
||||
def extract_by_training(resume):
|
||||
"""
|
||||
Final recovery process by extracting the neural network
|
||||
"""
|
||||
global inverse
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
|
||||
if not resume:
|
||||
inverse = PrivateEncoder(Xray(True)).cuda(device)
|
||||
|
||||
# More adam to train.
|
||||
optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001)
|
||||
|
||||
this_xs = ooriginal[global_matching]
|
||||
this_ys = encoded[:,global_permutation,:]
|
||||
|
||||
for i in range(2000):
|
||||
idx = np.random.random_integers(0, len(this_xs)-1, 32)
|
||||
xbatch = torch.tensor(this_xs[idx]).cuda(device)
|
||||
ybatch = torch.tensor(this_ys[idx]).cuda(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
guess_output = inverse(xbatch)
|
||||
# L1 loss because we don't want to be sensitive to outliers
|
||||
error = torch.mean(torch.abs(guess_output-ybatch))
|
||||
error.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
print(error)
|
||||
|
||||
|
||||
|
||||
def test_extract():
|
||||
"""
|
||||
Now we can recover the matching much better by computing the estimated
|
||||
encodings for each original image.
|
||||
"""
|
||||
global err, global_matching, guessed_encoded, smatrix
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
|
||||
print(ooriginal.shape, encoded.shape)
|
||||
|
||||
out = []
|
||||
for i in range(0,len(ooriginal),32):
|
||||
print(i)
|
||||
out.extend(inverse(torch.tensor(ooriginal[i:i+32]).cuda(device)).cpu().detach().numpy())
|
||||
|
||||
guessed_encoded = np.array(out)
|
||||
|
||||
|
||||
# Now we have to compare each encoded image with every other original image.
|
||||
# Do this fast with some matrix multiplies.
|
||||
|
||||
out = guessed_encoded.reshape((len(encoded), -1))
|
||||
real = encoded[:,global_permutation,:].reshape((len(encoded), -1))
|
||||
@jax.jit
|
||||
def foo(x, y):
|
||||
return jn.square(x[:,None] - y[None,:]).sum(2)
|
||||
|
||||
smatrix = np.zeros((len(out), len(out)))
|
||||
|
||||
B = 500
|
||||
for i in range(0,len(out),B):
|
||||
print(i)
|
||||
for j in range(0,len(out),B):
|
||||
smatrix[i:i+B, j:j+B] = foo(out[i:i+B], real[j:j+B])
|
||||
|
||||
# And the final time you'l have to look at a min weight matching, I promise.
|
||||
row, col = scipy.optimize.linear_sum_assignment(np.array(smatrix))
|
||||
r = np.array(smatrix)
|
||||
|
||||
print(list(row)[::100])
|
||||
|
||||
print("Differences", np.mean(np.argsort(col) != global_matching))
|
||||
|
||||
global_matching = np.argsort(col)
|
||||
|
||||
|
||||
def perf(steps=[]):
|
||||
if len(steps) == 0:
|
||||
steps.append(time.time())
|
||||
else:
|
||||
print("Last Time Elapsed:", time.time()-steps[-1], ' Total Time Elapsed:', time.time()-steps[0])
|
||||
steps.append(time.time())
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if True:
|
||||
perf()
|
||||
setup('challenge')
|
||||
perf()
|
||||
gen_train_data()
|
||||
perf()
|
||||
get_train_features()
|
||||
perf()
|
||||
train_model()
|
||||
perf()
|
||||
|
||||
if True:
|
||||
load_challenge()
|
||||
perf()
|
||||
recover_local_permutation()
|
||||
perf()
|
||||
recover_better_local_permutation()
|
||||
perf()
|
||||
compute_patch_similarity()
|
||||
perf()
|
||||
recover_global_matching_first()
|
||||
perf()
|
||||
|
||||
for _ in range(3):
|
||||
recover_global_permutation()
|
||||
perf()
|
||||
recover_global_matching_second()
|
||||
perf()
|
||||
|
||||
for i in range(3):
|
||||
recover_global_permutation()
|
||||
perf()
|
||||
extract_by_training(i > 0)
|
||||
perf()
|
||||
test_extract()
|
||||
perf()
|
||||
print(perf())
|
Loading…
Reference in a new issue