diff --git a/research/mi_lira_2021/README.md b/research/mi_lira_2021/README.md index 7cb30e1..72cd48f 100644 --- a/research/mi_lira_2021/README.md +++ b/research/mi_lira_2021/README.md @@ -2,10 +2,9 @@ This directory contains code to reproduce our paper: -**"Membership Inference Attacks From First Principles"** -https://arxiv.org/abs/2112.03570 -by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer. - +**"Membership Inference Attacks From First Principles"**
+https://arxiv.org/abs/2112.03570
+by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr. ### INSTALLING @@ -18,21 +17,20 @@ with JAX + ObJAX so you will need to follow build instructions for that https://github.com/google/objax https://objax.readthedocs.io/en/latest/installation_setup.html - ### RUNNING THE CODE #### 1. Train the models -The first step in our attack is to train shadow models. As a baseline -that should give most of the gains in our attack, you should start by -training 16 shadow models with the command +The first step in our attack is to train shadow models. As a baseline that +should give most of the gains in our attack, you should start by training 16 +shadow models with the command > bash scripts/train_demo.sh -or if you have multiple GPUs on your machine and want to train these models -in parallel, then modify and run +or if you have multiple GPUs on your machine and want to train these models in +parallel, then modify and run -> bash scripts/train_demo_multigpu.sh +> bash scripts/train_demo_multigpu.sh This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and will output a bunch of files under the directory exp/cifar10 with structure: @@ -63,14 +61,13 @@ exp/cifar10/ --- 0000000100.npy ``` -where this new file has shape (50000, 10) and stores the model's -output features for each example. - +where this new file has shape (50000, 10) and stores the model's output features +for each example. #### 3. Compute membership inference scores -Finally we take the output features and generate our logit-scaled membership inference -scores for each example for each model. +Finally we take the output features and generate our logit-scaled membership +inference scores for each example for each model. > python3 score.py exp/cifar10/ @@ -85,7 +82,6 @@ exp/cifar10/ with shape (50000,) storing just our scores. - ### PLOTTING THE RESULTS Finally we can generate pretty pictures, and run the plotting code @@ -94,7 +90,6 @@ Finally we can generate pretty pictures, and run the plotting code which should give (something like) the following output - ![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve") ``` @@ -111,9 +106,9 @@ Attack Global threshold ``` where the global threshold attack is the baseline, and our online, -online-with-fixed-variance, offline, and offline-with-fixed-variance -attack variants are the four other curves. Note that because we only -train a few models, the fixed variance variants perform best. +online-with-fixed-variance, offline, and offline-with-fixed-variance attack +variants are the four other curves. Note that because we only train a few +models, the fixed variance variants perform best. ### Citation @@ -126,4 +121,4 @@ You can cite this paper with journal={arXiv preprint arXiv:2112.03570}, year={2021} } -``` \ No newline at end of file +``` diff --git a/research/mi_lira_2021/inference.py b/research/mi_lira_2021/inference.py index 11ad696..9d78d0b 100644 --- a/research/mi_lira_2021/inference.py +++ b/research/mi_lira_2021/inference.py @@ -12,32 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -import os -from typing import Callable +# pylint: skip-file +# pyformat: disable + import json - +import os import re -import jax -import jax.numpy as jn + import numpy as np -import tensorflow as tf # For data augmentation. -import tensorflow_datasets as tfds -from absl import app, flags -from tqdm import tqdm, trange -import pickle -from functools import partial - import objax -from objax.jaxboard import SummaryWriter, Summary -from objax.util import EasyDict -from objax.zoo import convnet, wide_resnet +import tensorflow as tf # For data augmentation. +from absl import app +from absl import flags -from dataset import DataSet +from train import MemModule +from train import network -from train import MemModule, network - -from collections import defaultdict FLAGS = flags.FLAGS @@ -56,7 +46,7 @@ def main(argv): lr=.1, batch=0, epochs=0, - weight_decay=0) + weight_decay=0) def cache_load(arch): thing = [] @@ -68,8 +58,8 @@ def main(argv): xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size] ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size] - - + + def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1): outs = [] @@ -90,7 +80,7 @@ def main(argv): def features(model, xbatch, ybatch): return get_loss(model, xbatch, ybatch, shift=0, reflect=True, stride=1) - + for path in sorted(os.listdir(os.path.join(FLAGS.logdir))): if re.search(FLAGS.regex, path) is None: print("Skipping from regex") @@ -99,9 +89,9 @@ def main(argv): hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json"))) arch = hparams['arch'] model = cache_load(arch)() - + logdir = os.path.join(FLAGS.logdir, path) - + checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True) max_epoch, last_ckpt = checkpoint.restore(model.vars()) if max_epoch == 0: continue @@ -112,12 +102,12 @@ def main(argv): first = FLAGS.from_epoch else: first = max_epoch-1 - + for epoch in range(first,max_epoch+1): if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)): # no checkpoint saved here continue - + if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)): print("Skipping already generated file", epoch) continue @@ -127,7 +117,7 @@ def main(argv): except: print("Fail to load", epoch) continue - + stats = [] for i in range(0,len(xs_all),N): @@ -142,9 +132,7 @@ if __name__ == '__main__': flags.DEFINE_string('dataset', 'cifar10', 'Dataset.') flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.') flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching') - flags.DEFINE_bool('random_labels', False, 'use random labels.') flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.') flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.') - flags.DEFINE_integer('seed_mod', None, 'keep mod seed.') - flags.DEFINE_integer('modulus', 8, 'modulus.') app.run(main) + diff --git a/research/mi_lira_2021/plot.py b/research/mi_lira_2021/plot.py index 435125c..42b1a54 100644 --- a/research/mi_lira_2021/plot.py +++ b/research/mi_lira_2021/plot.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: skip-file +# pyformat: disable + import os import scipy.stats @@ -113,7 +116,7 @@ def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000 dat_out.append(scores[~keep[:, j], j, :]) out_size = min(min(map(len,dat_out)), out_size) - + dat_out = np.array([x[:out_size] for x in dat_out]) mean_out = np.median(dat_out, 1) @@ -160,7 +163,7 @@ def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, ** fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool)) low = tpr[np.where(fpr<.001)[0][-1]] - + print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low)) metric_text = '' @@ -206,7 +209,7 @@ def fig_fpr_tpr(): "Global threshold\n", metric='auc' ) - + plt.semilogx() plt.semilogy() plt.xlim(1e-5,1) @@ -220,5 +223,6 @@ def fig_fpr_tpr(): plt.show() -load_data("exp/cifar10/") -fig_fpr_tpr() +if __name__ == '__main__': + load_data("exp/cifar10/") + fig_fpr_tpr() diff --git a/research/mi_lira_2021/train.py b/research/mi_lira_2021/train.py index 19ff0e3..fa658ac 100644 --- a/research/mi_lira_2021/train.py +++ b/research/mi_lira_2021/train.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: skip-file +# pyformat: disable + import functools import os import shutil @@ -24,12 +27,11 @@ import numpy as np import tensorflow as tf # For data augmentation. import tensorflow_datasets as tfds from absl import app, flags -from tqdm import tqdm, trange import objax from objax.jaxboard import SummaryWriter, Summary from objax.util import EasyDict -from objax.zoo import convnet, wide_resnet, dnnet +from objax.zoo import convnet, wide_resnet from dataset import DataSet @@ -202,11 +204,11 @@ def get_data(seed): data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR)) inputs = data['train']['image'] labels = data['train']['label'] - + inputs = (inputs/127.5)-1 np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs) np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels) - + nclass = np.max(labels)+1 np.random.seed(seed) @@ -233,7 +235,7 @@ def get_data(seed): aug = lambda x: augment(x, 0, mirror=False) else: raise - + train = DataSet.from_arrays(xs, ys, augment_fn=aug) test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:]) @@ -252,7 +254,7 @@ def main(argv): import time seed = np.random.randint(0, 1000000000) seed ^= int(time.time()) - + args = EasyDict(arch=FLAGS.arch, lr=FLAGS.lr, batch=FLAGS.batch, @@ -260,7 +262,7 @@ def main(argv): augment=FLAGS.augment, seed=seed) - + if FLAGS.tunename: logdir = '_'.join(sorted('%s=%s' % k for k in args.items())) elif FLAGS.expid is not None: @@ -269,7 +271,7 @@ def main(argv): logdir = "experiment-"+str(seed) logdir = os.path.join(FLAGS.logdir, logdir) - if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)): + if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)): print(f"run {FLAGS.expid} already completed.") return else: @@ -282,7 +284,7 @@ def main(argv): os.makedirs(logdir) train, test, xs, ys, keep, nclass = get_data(seed) - + # Define the network and train_it tm = MemModule(network(FLAGS.arch), nclass=nclass, mnist=FLAGS.dataset == 'mnist', @@ -303,8 +305,8 @@ def main(argv): tm.train(FLAGS.epochs, len(xs), train, test, logdir, save_steps=FLAGS.save_steps, patience=FLAGS.patience) - - + + if __name__ == '__main__': flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.') @@ -327,3 +329,4 @@ if __name__ == '__main__': flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress') flags.DEFINE_bool('tunename', False, 'Use tune name?') app.run(main) + diff --git a/research/mi_poison_2022/README.md b/research/mi_poison_2022/README.md new file mode 100644 index 0000000..a20444d --- /dev/null +++ b/research/mi_poison_2022/README.md @@ -0,0 +1,116 @@ +## Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets + +This directory contains code to reproduce results from the paper: + +**"Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets"**
+https://arxiv.org/abs/2204.00032
+by Florian Tramèr, Reza Shokri, Ayrton San Joaquin, Hoang Le, Matthew Jagielski, Sanghyun Hong and Nicholas Carlini + +### INSTALLING + +The experiments in this directory are built on top of the +[LiRA membership inference attack](../mi_lira_2021). + +After following the [installation instructions](../mi_lira_2021#installing) for +LiRa, make sure the attack code is on your `PYTHONPATH`: + +```bash +export PYTHONPATH="${PYTHONPATH}:../mi_lira_2021" +``` + +### RUNNING THE CODE + +#### 1. Train the models + +The first step in our attack is to train shadow models, with some data points +targeted by a poisoning attack. You can train 16 shadow models with the command + +> bash scripts/train_demo.sh + +or if you have multiple GPUs on your machine and want to train these models in +parallel, then modify and run + +> bash scripts/train_demo_multigpu.sh + +This will train 16 CIFAR-10 wide ResNet models to ~91% accuracy each, with 250 +points targeted for poisoning. For each of these 250 targeted points, the +attacker adds 8 mislabeled poisoned copies of the point into the training set. +The training run will output a bunch of files under the directory exp/cifar10 +with structure: + +``` +exp/cifar10/ +- xtrain.npy +- ytain.npy +- poison_pos.npy +- experiment_N_of_16 +-- hparams.json +-- keep.npy +-- ckpt/ +--- 0000000100.npz +-- tb/ +``` + +The following flags control the poisoning attack: + +- `num_poison_targets (default=250)`. The number of targeted points. +- `poison_reps (default=8)`. The number of replicas per poison. +- `poison_pos_seed (default=0)`. The random seed to use to choose the target + points. + +We recommend that `num_poison_targets * poison_reps < 5000` on CIFAR-10, as +otherwise the poisons introduce too much label noise and the model's accuracy +(and the attack's success rate) will be degraded. + +#### 2. Perform inference and compute scores + +Exactly as for LiRA, we then evaluate the models on the entire CIFAR-10 dataset, +and generate logit-scaled membership inference scores. See +[here](../mi_lira_2021#2-perform-inference) and +[here](../mi_lira_2021#3-compute-membership-inference-scores) for details. + +```bash +python3 -m inference --logdir=exp/cifar10/ +python3 -m score exp/cifar10/ +``` + +### PLOTTING THE RESULTS + +Finally we can generate pretty pictures, and run the plotting code + +```bash +python3 plot_poison.py +``` + +which should give (something like) the following output + +![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve") + +``` +Attack No poison (LiRA) + AUC 0.7025, Accuracy 0.6258, TPR@0.1%FPR of 0.0544 +Attack No poison (Global threshold) + AUC 0.6191, Accuracy 0.6173, TPR@0.1%FPR of 0.0012 +Attack With poison (LiRA) + AUC 0.9943, Accuracy 0.9653, TPR@0.1%FPR of 0.4945 +Attack With poison (Global threshold) + AUC 0.9922, Accuracy 0.9603, TPR@0.1%FPR of 0.3930 +``` + +where the baselines are LiRA and a simple global threshold on the membership +scores, both without poisoning. With poisoning, both LiRA and the global +threshold attack are boosted significantly. Note that because we only train a +few models, we use the fixed variance variant of LiRA. + +### Citation + +You can cite this paper with + +``` +@article{tramer2022truth, + title={Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets}, + author={Tramer, Florian and Shokri, Reza and San Joaquin, Ayrton and Le, Hoang and Jagielski, Matthew and Hong, Sanghyun and Carlini, Nicholas}, + journal={arXiv preprint arXiv:2204.00032}, + year={2022} +} +``` diff --git a/research/mi_poison_2022/fprtpr.png b/research/mi_poison_2022/fprtpr.png new file mode 100644 index 0000000..a870cb9 Binary files /dev/null and b/research/mi_poison_2022/fprtpr.png differ diff --git a/research/mi_poison_2022/logs/.keep b/research/mi_poison_2022/logs/.keep new file mode 100644 index 0000000..bb545f4 --- /dev/null +++ b/research/mi_poison_2022/logs/.keep @@ -0,0 +1,13 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/research/mi_poison_2022/plot_poison.py b/research/mi_poison_2022/plot_poison.py new file mode 100644 index 0000000..03306f5 --- /dev/null +++ b/research/mi_poison_2022/plot_poison.py @@ -0,0 +1,116 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +# pyformat: disable + +import os +import numpy as np +import matplotlib.pyplot as plt +import functools + +import matplotlib +matplotlib.rcParams['pdf.fonttype'] = 42 +matplotlib.rcParams['ps.fonttype'] = 42 + +# from mi_lira_2021 +from plot import sweep, load_data, generate_ours, generate_global + + +def do_plot_all(fn, keep, scores, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs): + """ + Generate the ROC curves by using one model as test model and the rest to train, + with a full leave-one-out cross-validation. + """ + + all_predictions = [] + all_answers = [] + for i in range(0, len(keep)): + mask = np.zeros(len(keep), dtype=bool) + mask[i:i+1] = True + prediction, answers = fn(keep[~mask], + scores[~mask], + keep[mask], + scores[mask]) + all_predictions.extend(prediction) + all_answers.extend(answers) + + fpr, tpr, auc, acc = sweep_fn(np.array(all_predictions), + np.array(all_answers, dtype=bool)) + + low = tpr[np.where(fpr < .001)[0][-1]] + print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc, acc, low)) + + metric_text = '' + if metric == 'auc': + metric_text = 'auc=%.3f' % auc + elif metric == 'acc': + metric_text = 'acc=%.3f' % acc + + plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs) + return acc, auc + + +def fig_fpr_tpr(poison_mask, scores, keep): + + plt.figure(figsize=(4, 3)) + + # evaluate LiRA on the points that were not targeted by poisoning + do_plot_all(functools.partial(generate_ours, fix_variance=True), + keep[:, ~poison_mask], scores[:, ~poison_mask], + "No poison (LiRA)\n", + metric='auc', + ) + + # evaluate the global-threshold attack on the points that were not targeted by poisoning + do_plot_all(generate_global, + keep[:, ~poison_mask], scores[:, ~poison_mask], + "No poison (Global threshold)\n", + metric='auc', ls="--", c=plt.gca().lines[-1].get_color() + ) + + # evaluate LiRA on the points that were targeted by poisoning + do_plot_all(functools.partial(generate_ours, fix_variance=True), + keep[:, poison_mask], scores[:, poison_mask], + "With poison (LiRA)\n", + metric='auc', + ) + + # evaluate the global-threshold attack on the points that were targeted by poisoning + do_plot_all(generate_global, + keep[:, poison_mask], scores[:, poison_mask], + "With poison (Global threshold)\n", + metric='auc', ls="--", c=plt.gca().lines[-1].get_color() + ) + + plt.semilogx() + plt.semilogy() + plt.xlim(1e-3, 1) + plt.ylim(1e-3, 1) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.plot([0, 1], [0, 1], ls='--', color='gray') + plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96) + plt.legend(fontsize=8) + plt.savefig("/tmp/fprtpr.png") + plt.show() + + +if __name__ == '__main__': + logdir = "exp/cifar10/" + scores, keep = load_data(logdir) + poison_pos = np.load(os.path.join(logdir, "poison_pos.npy")) + poison_mask = np.zeros(scores.shape[1], dtype=bool) + poison_mask[poison_pos] = True + fig_fpr_tpr(poison_mask, scores, keep) diff --git a/research/mi_poison_2022/scripts/train_demo.sh b/research/mi_poison_2022/scripts/train_demo.sh new file mode 100644 index 0000000..16b1434 --- /dev/null +++ b/research/mi_poison_2022/scripts/train_demo.sh @@ -0,0 +1,30 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 diff --git a/research/mi_poison_2022/scripts/train_demo_multigpu.sh b/research/mi_poison_2022/scripts/train_demo_multigpu.sh new file mode 100644 index 0000000..7d8d81f --- /dev/null +++ b/research/mi_poison_2022/scripts/train_demo_multigpu.sh @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 & +CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 & +CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 & +CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 & +CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 & +CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 & +CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 & +CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 & +wait; +CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 & +CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 & +CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 & +CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 & +CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 & +CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 & +CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 & +CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 & +wait; diff --git a/research/mi_poison_2022/train_poison.py b/research/mi_poison_2022/train_poison.py new file mode 100644 index 0000000..f2afc2c --- /dev/null +++ b/research/mi_poison_2022/train_poison.py @@ -0,0 +1,214 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +# pyformat: disable + +import os +import shutil +import json + +import numpy as np +import tensorflow as tf # For data augmentation. +import tensorflow_datasets as tfds +from absl import app, flags + +from objax.util import EasyDict + +# from mi_lira_2021 +from dataset import DataSet +from train import augment, MemModule, network + +FLAGS = flags.FLAGS + + +def get_data(seed): + """ + This is the function to generate subsets of the data for training models. + + First, we get the training dataset either from the numpy cache + or otherwise we load it from tensorflow datasets. + + Then, we compute the subset. This works in one of two ways. + + 1. If we have a seed, then we just randomly choose examples based on + a prng with that seed, keeping FLAGS.pkeep fraction of the data. + + 2. Otherwise, if we have an experiment ID, then we do something fancier. + If we run each experiment independently then even after a lot of trials + there will still probably be some examples that were always included + or always excluded. So instead, with experiment IDs, we guarantee that + after FLAGS.num_experiments are done, each example is seen exactly half + of the time in train, and half of the time not in train. + + Finally, we add some poisons. The same poisoned samples are added for + each randomly generated training set. + We first select FLAGS.num_poison_targets victim points that will be targeted + by the poisoning attack. For each of these victim points, the attacker will + insert FLAGS.poison_reps mislabeled replicas of the point into the training + set. + + For CIFAR-10, we recommend that: + + `FLAGS.num_poison_targets * FLAGS.poison_reps < 5000` + + Otherwise, the poisons might introduce too much label noise and the model's + accuracy (and the attack's success rate) will be degraded. + """ + DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS') + + if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")): + inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy")) + labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy")) + else: + print("First time, creating dataset") + data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR)) + inputs = data['train']['image'] + labels = data['train']['label'] + + inputs = (inputs/127.5)-1 + np.save(os.path.join(FLAGS.logdir, "x_train.npy"), inputs) + np.save(os.path.join(FLAGS.logdir, "y_train.npy"), labels) + + nclass = np.max(labels)+1 + + np.random.seed(seed) + if FLAGS.num_experiments is not None: + np.random.seed(0) + keep = np.random.uniform(0, 1, size=(FLAGS.num_experiments, len(inputs))) + order = keep.argsort(0) + keep = order < int(FLAGS.pkeep * FLAGS.num_experiments) + keep = np.array(keep[FLAGS.expid], dtype=bool) + else: + keep = np.random.uniform(0, 1, size=len(inputs)) <= FLAGS.pkeep + + xs = inputs[keep] + ys = labels[keep] + + if FLAGS.num_poison_targets > 0: + + # select some points as targets + np.random.seed(FLAGS.poison_pos_seed) + poison_pos = np.random.choice(len(inputs), size=FLAGS.num_poison_targets, replace=False) + + # create mislabeled poisons for the targeted points and replicate each + # poison `FLAGS.poison_reps` times + y_noise = np.mod(labels[poison_pos] + np.random.randint(low=1, high=nclass, size=FLAGS.num_poison_targets), nclass) + ypoison = np.repeat(y_noise, FLAGS.poison_reps) + xpoison = np.repeat(inputs[poison_pos], FLAGS.poison_reps, axis=0) + xs = np.concatenate((xs, xpoison), axis=0) + ys = np.concatenate((ys, ypoison), axis=0) + + if not os.path.exists(os.path.join(FLAGS.logdir, "poison_pos.npy")): + np.save(os.path.join(FLAGS.logdir, "poison_pos.npy"), poison_pos) + + if FLAGS.augment == 'weak': + aug = lambda x: augment(x, 4) + elif FLAGS.augment == 'mirror': + aug = lambda x: augment(x, 0) + elif FLAGS.augment == 'none': + aug = lambda x: augment(x, 0, mirror=False) + else: + raise + + print(xs.shape, ys.shape) + train = DataSet.from_arrays(xs, ys, + augment_fn=aug) + test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:]) + train = train.cache().shuffle(len(xs)).repeat().parse().augment().batch(FLAGS.batch) + train = train.nchw().one_hot(nclass).prefetch(FLAGS.batch) + test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(FLAGS.batch) + + return train, test, xs, ys, keep, nclass + + +def main(argv): + del argv + tf.config.experimental.set_visible_devices([], "GPU") + + seed = FLAGS.seed + if seed is None: + import time + seed = np.random.randint(0, 1000000000) + seed ^= int(time.time()) + + args = EasyDict(arch=FLAGS.arch, + lr=FLAGS.lr, + batch=FLAGS.batch, + weight_decay=FLAGS.weight_decay, + augment=FLAGS.augment, + seed=seed) + + if FLAGS.expid is not None: + logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments) + else: + logdir = "experiment-"+str(seed) + logdir = os.path.join(FLAGS.logdir, logdir) + + if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % FLAGS.epochs)): + print(f"run {FLAGS.expid} already completed.") + return + else: + if os.path.exists(logdir): + print(f"deleting run {FLAGS.expid} that did not complete.") + shutil.rmtree(logdir) + + print(f"starting run {FLAGS.expid}.") + if not os.path.exists(logdir): + os.makedirs(logdir) + + train, test, xs, ys, keep, nclass = get_data(seed) + + # Define the network and train_it + tm = MemModule(network(FLAGS.arch), nclass=nclass, + mnist=FLAGS.dataset == 'mnist', + epochs=FLAGS.epochs, + expid=FLAGS.expid, + num_experiments=FLAGS.num_experiments, + pkeep=FLAGS.pkeep, + save_steps=FLAGS.save_steps, + **args + ) + + r = {} + r.update(tm.params) + + open(os.path.join(logdir, 'hparams.json'), "w").write(json.dumps(tm.params)) + np.save(os.path.join(logdir,'keep.npy'), keep) + + tm.train(FLAGS.epochs, len(xs), train, test, logdir, + save_steps=FLAGS.save_steps) + + +if __name__ == '__main__': + flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.') + flags.DEFINE_float('lr', 0.1, 'Learning rate.') + flags.DEFINE_string('dataset', 'cifar10', 'Dataset.') + flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.') + flags.DEFINE_integer('batch', 256, 'Batch size') + flags.DEFINE_integer('epochs', 100, 'Training duration in number of epochs.') + flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.') + flags.DEFINE_integer('seed', None, 'Training seed.') + flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.') + flags.DEFINE_integer('expid', None, 'Experiment ID') + flags.DEFINE_integer('num_experiments', None, 'Number of experiments') + flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation') + flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.') + flags.DEFINE_integer('save_steps', 10, 'how often to get save model.') + + flags.DEFINE_integer('num_poison_targets', 250, 'Number of points to target ' + 'with the poisoning attack.') + flags.DEFINE_integer('poison_reps', 8, 'Number of times to repeat each poison.') + flags.DEFINE_integer('poison_pos_seed', 0, '') + app.run(main)