diff --git a/research/mi_lira_2021/README.md b/research/mi_lira_2021/README.md
index 7cb30e1..72cd48f 100644
--- a/research/mi_lira_2021/README.md
+++ b/research/mi_lira_2021/README.md
@@ -2,10 +2,9 @@
This directory contains code to reproduce our paper:
-**"Membership Inference Attacks From First Principles"**
-https://arxiv.org/abs/2112.03570
-by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
-
+**"Membership Inference Attacks From First Principles"**
+https://arxiv.org/abs/2112.03570
+by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr.
### INSTALLING
@@ -18,21 +17,20 @@ with JAX + ObJAX so you will need to follow build instructions for that
https://github.com/google/objax
https://objax.readthedocs.io/en/latest/installation_setup.html
-
### RUNNING THE CODE
#### 1. Train the models
-The first step in our attack is to train shadow models. As a baseline
-that should give most of the gains in our attack, you should start by
-training 16 shadow models with the command
+The first step in our attack is to train shadow models. As a baseline that
+should give most of the gains in our attack, you should start by training 16
+shadow models with the command
> bash scripts/train_demo.sh
-or if you have multiple GPUs on your machine and want to train these models
-in parallel, then modify and run
+or if you have multiple GPUs on your machine and want to train these models in
+parallel, then modify and run
-> bash scripts/train_demo_multigpu.sh
+> bash scripts/train_demo_multigpu.sh
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
will output a bunch of files under the directory exp/cifar10 with structure:
@@ -63,14 +61,13 @@ exp/cifar10/
--- 0000000100.npy
```
-where this new file has shape (50000, 10) and stores the model's
-output features for each example.
-
+where this new file has shape (50000, 10) and stores the model's output features
+for each example.
#### 3. Compute membership inference scores
-Finally we take the output features and generate our logit-scaled membership inference
-scores for each example for each model.
+Finally we take the output features and generate our logit-scaled membership
+inference scores for each example for each model.
> python3 score.py exp/cifar10/
@@ -85,7 +82,6 @@ exp/cifar10/
with shape (50000,) storing just our scores.
-
### PLOTTING THE RESULTS
Finally we can generate pretty pictures, and run the plotting code
@@ -94,7 +90,6 @@ Finally we can generate pretty pictures, and run the plotting code
which should give (something like) the following output
-
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
```
@@ -111,9 +106,9 @@ Attack Global threshold
```
where the global threshold attack is the baseline, and our online,
-online-with-fixed-variance, offline, and offline-with-fixed-variance
-attack variants are the four other curves. Note that because we only
-train a few models, the fixed variance variants perform best.
+online-with-fixed-variance, offline, and offline-with-fixed-variance attack
+variants are the four other curves. Note that because we only train a few
+models, the fixed variance variants perform best.
### Citation
@@ -126,4 +121,4 @@ You can cite this paper with
journal={arXiv preprint arXiv:2112.03570},
year={2021}
}
-```
\ No newline at end of file
+```
diff --git a/research/mi_lira_2021/inference.py b/research/mi_lira_2021/inference.py
index 11ad696..9d78d0b 100644
--- a/research/mi_lira_2021/inference.py
+++ b/research/mi_lira_2021/inference.py
@@ -12,32 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import functools
-import os
-from typing import Callable
+# pylint: skip-file
+# pyformat: disable
+
import json
-
+import os
import re
-import jax
-import jax.numpy as jn
+
import numpy as np
-import tensorflow as tf # For data augmentation.
-import tensorflow_datasets as tfds
-from absl import app, flags
-from tqdm import tqdm, trange
-import pickle
-from functools import partial
-
import objax
-from objax.jaxboard import SummaryWriter, Summary
-from objax.util import EasyDict
-from objax.zoo import convnet, wide_resnet
+import tensorflow as tf # For data augmentation.
+from absl import app
+from absl import flags
-from dataset import DataSet
+from train import MemModule
+from train import network
-from train import MemModule, network
-
-from collections import defaultdict
FLAGS = flags.FLAGS
@@ -56,7 +46,7 @@ def main(argv):
lr=.1,
batch=0,
epochs=0,
- weight_decay=0)
+ weight_decay=0)
def cache_load(arch):
thing = []
@@ -68,8 +58,8 @@ def main(argv):
xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
-
-
+
+
def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
outs = []
@@ -90,7 +80,7 @@ def main(argv):
def features(model, xbatch, ybatch):
return get_loss(model, xbatch, ybatch,
shift=0, reflect=True, stride=1)
-
+
for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
if re.search(FLAGS.regex, path) is None:
print("Skipping from regex")
@@ -99,9 +89,9 @@ def main(argv):
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
arch = hparams['arch']
model = cache_load(arch)()
-
+
logdir = os.path.join(FLAGS.logdir, path)
-
+
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
max_epoch, last_ckpt = checkpoint.restore(model.vars())
if max_epoch == 0: continue
@@ -112,12 +102,12 @@ def main(argv):
first = FLAGS.from_epoch
else:
first = max_epoch-1
-
+
for epoch in range(first,max_epoch+1):
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
# no checkpoint saved here
continue
-
+
if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
print("Skipping already generated file", epoch)
continue
@@ -127,7 +117,7 @@ def main(argv):
except:
print("Fail to load", epoch)
continue
-
+
stats = []
for i in range(0,len(xs_all),N):
@@ -142,9 +132,7 @@ if __name__ == '__main__':
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
- flags.DEFINE_bool('random_labels', False, 'use random labels.')
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
- flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
- flags.DEFINE_integer('modulus', 8, 'modulus.')
app.run(main)
+
diff --git a/research/mi_lira_2021/plot.py b/research/mi_lira_2021/plot.py
index 435125c..42b1a54 100644
--- a/research/mi_lira_2021/plot.py
+++ b/research/mi_lira_2021/plot.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# pylint: skip-file
+# pyformat: disable
+
import os
import scipy.stats
@@ -113,7 +116,7 @@ def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000
dat_out.append(scores[~keep[:, j], j, :])
out_size = min(min(map(len,dat_out)), out_size)
-
+
dat_out = np.array([x[:out_size] for x in dat_out])
mean_out = np.median(dat_out, 1)
@@ -160,7 +163,7 @@ def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
low = tpr[np.where(fpr<.001)[0][-1]]
-
+
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
metric_text = ''
@@ -206,7 +209,7 @@ def fig_fpr_tpr():
"Global threshold\n",
metric='auc'
)
-
+
plt.semilogx()
plt.semilogy()
plt.xlim(1e-5,1)
@@ -220,5 +223,6 @@ def fig_fpr_tpr():
plt.show()
-load_data("exp/cifar10/")
-fig_fpr_tpr()
+if __name__ == '__main__':
+ load_data("exp/cifar10/")
+ fig_fpr_tpr()
diff --git a/research/mi_lira_2021/train.py b/research/mi_lira_2021/train.py
index 19ff0e3..fa658ac 100644
--- a/research/mi_lira_2021/train.py
+++ b/research/mi_lira_2021/train.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# pylint: skip-file
+# pyformat: disable
+
import functools
import os
import shutil
@@ -24,12 +27,11 @@ import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
-from tqdm import tqdm, trange
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
-from objax.zoo import convnet, wide_resnet, dnnet
+from objax.zoo import convnet, wide_resnet
from dataset import DataSet
@@ -202,11 +204,11 @@ def get_data(seed):
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']
-
+
inputs = (inputs/127.5)-1
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
-
+
nclass = np.max(labels)+1
np.random.seed(seed)
@@ -233,7 +235,7 @@ def get_data(seed):
aug = lambda x: augment(x, 0, mirror=False)
else:
raise
-
+
train = DataSet.from_arrays(xs, ys,
augment_fn=aug)
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
@@ -252,7 +254,7 @@ def main(argv):
import time
seed = np.random.randint(0, 1000000000)
seed ^= int(time.time())
-
+
args = EasyDict(arch=FLAGS.arch,
lr=FLAGS.lr,
batch=FLAGS.batch,
@@ -260,7 +262,7 @@ def main(argv):
augment=FLAGS.augment,
seed=seed)
-
+
if FLAGS.tunename:
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
elif FLAGS.expid is not None:
@@ -269,7 +271,7 @@ def main(argv):
logdir = "experiment-"+str(seed)
logdir = os.path.join(FLAGS.logdir, logdir)
- if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
+ if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)):
print(f"run {FLAGS.expid} already completed.")
return
else:
@@ -282,7 +284,7 @@ def main(argv):
os.makedirs(logdir)
train, test, xs, ys, keep, nclass = get_data(seed)
-
+
# Define the network and train_it
tm = MemModule(network(FLAGS.arch), nclass=nclass,
mnist=FLAGS.dataset == 'mnist',
@@ -303,8 +305,8 @@ def main(argv):
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
save_steps=FLAGS.save_steps, patience=FLAGS.patience)
-
-
+
+
if __name__ == '__main__':
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
@@ -327,3 +329,4 @@ if __name__ == '__main__':
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
flags.DEFINE_bool('tunename', False, 'Use tune name?')
app.run(main)
+
diff --git a/research/mi_poison_2022/README.md b/research/mi_poison_2022/README.md
new file mode 100644
index 0000000..a20444d
--- /dev/null
+++ b/research/mi_poison_2022/README.md
@@ -0,0 +1,116 @@
+## Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets
+
+This directory contains code to reproduce results from the paper:
+
+**"Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets"**
+https://arxiv.org/abs/2204.00032
+by Florian Tramèr, Reza Shokri, Ayrton San Joaquin, Hoang Le, Matthew Jagielski, Sanghyun Hong and Nicholas Carlini
+
+### INSTALLING
+
+The experiments in this directory are built on top of the
+[LiRA membership inference attack](../mi_lira_2021).
+
+After following the [installation instructions](../mi_lira_2021#installing) for
+LiRa, make sure the attack code is on your `PYTHONPATH`:
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:../mi_lira_2021"
+```
+
+### RUNNING THE CODE
+
+#### 1. Train the models
+
+The first step in our attack is to train shadow models, with some data points
+targeted by a poisoning attack. You can train 16 shadow models with the command
+
+> bash scripts/train_demo.sh
+
+or if you have multiple GPUs on your machine and want to train these models in
+parallel, then modify and run
+
+> bash scripts/train_demo_multigpu.sh
+
+This will train 16 CIFAR-10 wide ResNet models to ~91% accuracy each, with 250
+points targeted for poisoning. For each of these 250 targeted points, the
+attacker adds 8 mislabeled poisoned copies of the point into the training set.
+The training run will output a bunch of files under the directory exp/cifar10
+with structure:
+
+```
+exp/cifar10/
+- xtrain.npy
+- ytain.npy
+- poison_pos.npy
+- experiment_N_of_16
+-- hparams.json
+-- keep.npy
+-- ckpt/
+--- 0000000100.npz
+-- tb/
+```
+
+The following flags control the poisoning attack:
+
+- `num_poison_targets (default=250)`. The number of targeted points.
+- `poison_reps (default=8)`. The number of replicas per poison.
+- `poison_pos_seed (default=0)`. The random seed to use to choose the target
+ points.
+
+We recommend that `num_poison_targets * poison_reps < 5000` on CIFAR-10, as
+otherwise the poisons introduce too much label noise and the model's accuracy
+(and the attack's success rate) will be degraded.
+
+#### 2. Perform inference and compute scores
+
+Exactly as for LiRA, we then evaluate the models on the entire CIFAR-10 dataset,
+and generate logit-scaled membership inference scores. See
+[here](../mi_lira_2021#2-perform-inference) and
+[here](../mi_lira_2021#3-compute-membership-inference-scores) for details.
+
+```bash
+python3 -m inference --logdir=exp/cifar10/
+python3 -m score exp/cifar10/
+```
+
+### PLOTTING THE RESULTS
+
+Finally we can generate pretty pictures, and run the plotting code
+
+```bash
+python3 plot_poison.py
+```
+
+which should give (something like) the following output
+
+![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
+
+```
+Attack No poison (LiRA)
+ AUC 0.7025, Accuracy 0.6258, TPR@0.1%FPR of 0.0544
+Attack No poison (Global threshold)
+ AUC 0.6191, Accuracy 0.6173, TPR@0.1%FPR of 0.0012
+Attack With poison (LiRA)
+ AUC 0.9943, Accuracy 0.9653, TPR@0.1%FPR of 0.4945
+Attack With poison (Global threshold)
+ AUC 0.9922, Accuracy 0.9603, TPR@0.1%FPR of 0.3930
+```
+
+where the baselines are LiRA and a simple global threshold on the membership
+scores, both without poisoning. With poisoning, both LiRA and the global
+threshold attack are boosted significantly. Note that because we only train a
+few models, we use the fixed variance variant of LiRA.
+
+### Citation
+
+You can cite this paper with
+
+```
+@article{tramer2022truth,
+ title={Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets},
+ author={Tramer, Florian and Shokri, Reza and San Joaquin, Ayrton and Le, Hoang and Jagielski, Matthew and Hong, Sanghyun and Carlini, Nicholas},
+ journal={arXiv preprint arXiv:2204.00032},
+ year={2022}
+}
+```
diff --git a/research/mi_poison_2022/fprtpr.png b/research/mi_poison_2022/fprtpr.png
new file mode 100644
index 0000000..a870cb9
Binary files /dev/null and b/research/mi_poison_2022/fprtpr.png differ
diff --git a/research/mi_poison_2022/logs/.keep b/research/mi_poison_2022/logs/.keep
new file mode 100644
index 0000000..bb545f4
--- /dev/null
+++ b/research/mi_poison_2022/logs/.keep
@@ -0,0 +1,13 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/research/mi_poison_2022/plot_poison.py b/research/mi_poison_2022/plot_poison.py
new file mode 100644
index 0000000..03306f5
--- /dev/null
+++ b/research/mi_poison_2022/plot_poison.py
@@ -0,0 +1,116 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+# pyformat: disable
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import functools
+
+import matplotlib
+matplotlib.rcParams['pdf.fonttype'] = 42
+matplotlib.rcParams['ps.fonttype'] = 42
+
+# from mi_lira_2021
+from plot import sweep, load_data, generate_ours, generate_global
+
+
+def do_plot_all(fn, keep, scores, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
+ """
+ Generate the ROC curves by using one model as test model and the rest to train,
+ with a full leave-one-out cross-validation.
+ """
+
+ all_predictions = []
+ all_answers = []
+ for i in range(0, len(keep)):
+ mask = np.zeros(len(keep), dtype=bool)
+ mask[i:i+1] = True
+ prediction, answers = fn(keep[~mask],
+ scores[~mask],
+ keep[mask],
+ scores[mask])
+ all_predictions.extend(prediction)
+ all_answers.extend(answers)
+
+ fpr, tpr, auc, acc = sweep_fn(np.array(all_predictions),
+ np.array(all_answers, dtype=bool))
+
+ low = tpr[np.where(fpr < .001)[0][-1]]
+ print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc, acc, low))
+
+ metric_text = ''
+ if metric == 'auc':
+ metric_text = 'auc=%.3f' % auc
+ elif metric == 'acc':
+ metric_text = 'acc=%.3f' % acc
+
+ plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
+ return acc, auc
+
+
+def fig_fpr_tpr(poison_mask, scores, keep):
+
+ plt.figure(figsize=(4, 3))
+
+ # evaluate LiRA on the points that were not targeted by poisoning
+ do_plot_all(functools.partial(generate_ours, fix_variance=True),
+ keep[:, ~poison_mask], scores[:, ~poison_mask],
+ "No poison (LiRA)\n",
+ metric='auc',
+ )
+
+ # evaluate the global-threshold attack on the points that were not targeted by poisoning
+ do_plot_all(generate_global,
+ keep[:, ~poison_mask], scores[:, ~poison_mask],
+ "No poison (Global threshold)\n",
+ metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
+ )
+
+ # evaluate LiRA on the points that were targeted by poisoning
+ do_plot_all(functools.partial(generate_ours, fix_variance=True),
+ keep[:, poison_mask], scores[:, poison_mask],
+ "With poison (LiRA)\n",
+ metric='auc',
+ )
+
+ # evaluate the global-threshold attack on the points that were targeted by poisoning
+ do_plot_all(generate_global,
+ keep[:, poison_mask], scores[:, poison_mask],
+ "With poison (Global threshold)\n",
+ metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
+ )
+
+ plt.semilogx()
+ plt.semilogy()
+ plt.xlim(1e-3, 1)
+ plt.ylim(1e-3, 1)
+ plt.xlabel("False Positive Rate")
+ plt.ylabel("True Positive Rate")
+ plt.plot([0, 1], [0, 1], ls='--', color='gray')
+ plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
+ plt.legend(fontsize=8)
+ plt.savefig("/tmp/fprtpr.png")
+ plt.show()
+
+
+if __name__ == '__main__':
+ logdir = "exp/cifar10/"
+ scores, keep = load_data(logdir)
+ poison_pos = np.load(os.path.join(logdir, "poison_pos.npy"))
+ poison_mask = np.zeros(scores.shape[1], dtype=bool)
+ poison_mask[poison_pos] = True
+ fig_fpr_tpr(poison_mask, scores, keep)
diff --git a/research/mi_poison_2022/scripts/train_demo.sh b/research/mi_poison_2022/scripts/train_demo.sh
new file mode 100644
index 0000000..16b1434
--- /dev/null
+++ b/research/mi_poison_2022/scripts/train_demo.sh
@@ -0,0 +1,30 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
diff --git a/research/mi_poison_2022/scripts/train_demo_multigpu.sh b/research/mi_poison_2022/scripts/train_demo_multigpu.sh
new file mode 100644
index 0000000..7d8d81f
--- /dev/null
+++ b/research/mi_poison_2022/scripts/train_demo_multigpu.sh
@@ -0,0 +1,32 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
+CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
+CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
+CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
+CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
+CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
+CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
+CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
+wait;
+CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
+CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
+CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
+CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
+CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
+CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
+CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
+CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
+wait;
diff --git a/research/mi_poison_2022/train_poison.py b/research/mi_poison_2022/train_poison.py
new file mode 100644
index 0000000..f2afc2c
--- /dev/null
+++ b/research/mi_poison_2022/train_poison.py
@@ -0,0 +1,214 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+# pyformat: disable
+
+import os
+import shutil
+import json
+
+import numpy as np
+import tensorflow as tf # For data augmentation.
+import tensorflow_datasets as tfds
+from absl import app, flags
+
+from objax.util import EasyDict
+
+# from mi_lira_2021
+from dataset import DataSet
+from train import augment, MemModule, network
+
+FLAGS = flags.FLAGS
+
+
+def get_data(seed):
+ """
+ This is the function to generate subsets of the data for training models.
+
+ First, we get the training dataset either from the numpy cache
+ or otherwise we load it from tensorflow datasets.
+
+ Then, we compute the subset. This works in one of two ways.
+
+ 1. If we have a seed, then we just randomly choose examples based on
+ a prng with that seed, keeping FLAGS.pkeep fraction of the data.
+
+ 2. Otherwise, if we have an experiment ID, then we do something fancier.
+ If we run each experiment independently then even after a lot of trials
+ there will still probably be some examples that were always included
+ or always excluded. So instead, with experiment IDs, we guarantee that
+ after FLAGS.num_experiments are done, each example is seen exactly half
+ of the time in train, and half of the time not in train.
+
+ Finally, we add some poisons. The same poisoned samples are added for
+ each randomly generated training set.
+ We first select FLAGS.num_poison_targets victim points that will be targeted
+ by the poisoning attack. For each of these victim points, the attacker will
+ insert FLAGS.poison_reps mislabeled replicas of the point into the training
+ set.
+
+ For CIFAR-10, we recommend that:
+
+ `FLAGS.num_poison_targets * FLAGS.poison_reps < 5000`
+
+ Otherwise, the poisons might introduce too much label noise and the model's
+ accuracy (and the attack's success rate) will be degraded.
+ """
+ DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
+
+ if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
+ inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
+ labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
+ else:
+ print("First time, creating dataset")
+ data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
+ inputs = data['train']['image']
+ labels = data['train']['label']
+
+ inputs = (inputs/127.5)-1
+ np.save(os.path.join(FLAGS.logdir, "x_train.npy"), inputs)
+ np.save(os.path.join(FLAGS.logdir, "y_train.npy"), labels)
+
+ nclass = np.max(labels)+1
+
+ np.random.seed(seed)
+ if FLAGS.num_experiments is not None:
+ np.random.seed(0)
+ keep = np.random.uniform(0, 1, size=(FLAGS.num_experiments, len(inputs)))
+ order = keep.argsort(0)
+ keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
+ keep = np.array(keep[FLAGS.expid], dtype=bool)
+ else:
+ keep = np.random.uniform(0, 1, size=len(inputs)) <= FLAGS.pkeep
+
+ xs = inputs[keep]
+ ys = labels[keep]
+
+ if FLAGS.num_poison_targets > 0:
+
+ # select some points as targets
+ np.random.seed(FLAGS.poison_pos_seed)
+ poison_pos = np.random.choice(len(inputs), size=FLAGS.num_poison_targets, replace=False)
+
+ # create mislabeled poisons for the targeted points and replicate each
+ # poison `FLAGS.poison_reps` times
+ y_noise = np.mod(labels[poison_pos] + np.random.randint(low=1, high=nclass, size=FLAGS.num_poison_targets), nclass)
+ ypoison = np.repeat(y_noise, FLAGS.poison_reps)
+ xpoison = np.repeat(inputs[poison_pos], FLAGS.poison_reps, axis=0)
+ xs = np.concatenate((xs, xpoison), axis=0)
+ ys = np.concatenate((ys, ypoison), axis=0)
+
+ if not os.path.exists(os.path.join(FLAGS.logdir, "poison_pos.npy")):
+ np.save(os.path.join(FLAGS.logdir, "poison_pos.npy"), poison_pos)
+
+ if FLAGS.augment == 'weak':
+ aug = lambda x: augment(x, 4)
+ elif FLAGS.augment == 'mirror':
+ aug = lambda x: augment(x, 0)
+ elif FLAGS.augment == 'none':
+ aug = lambda x: augment(x, 0, mirror=False)
+ else:
+ raise
+
+ print(xs.shape, ys.shape)
+ train = DataSet.from_arrays(xs, ys,
+ augment_fn=aug)
+ test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
+ train = train.cache().shuffle(len(xs)).repeat().parse().augment().batch(FLAGS.batch)
+ train = train.nchw().one_hot(nclass).prefetch(FLAGS.batch)
+ test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(FLAGS.batch)
+
+ return train, test, xs, ys, keep, nclass
+
+
+def main(argv):
+ del argv
+ tf.config.experimental.set_visible_devices([], "GPU")
+
+ seed = FLAGS.seed
+ if seed is None:
+ import time
+ seed = np.random.randint(0, 1000000000)
+ seed ^= int(time.time())
+
+ args = EasyDict(arch=FLAGS.arch,
+ lr=FLAGS.lr,
+ batch=FLAGS.batch,
+ weight_decay=FLAGS.weight_decay,
+ augment=FLAGS.augment,
+ seed=seed)
+
+ if FLAGS.expid is not None:
+ logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
+ else:
+ logdir = "experiment-"+str(seed)
+ logdir = os.path.join(FLAGS.logdir, logdir)
+
+ if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % FLAGS.epochs)):
+ print(f"run {FLAGS.expid} already completed.")
+ return
+ else:
+ if os.path.exists(logdir):
+ print(f"deleting run {FLAGS.expid} that did not complete.")
+ shutil.rmtree(logdir)
+
+ print(f"starting run {FLAGS.expid}.")
+ if not os.path.exists(logdir):
+ os.makedirs(logdir)
+
+ train, test, xs, ys, keep, nclass = get_data(seed)
+
+ # Define the network and train_it
+ tm = MemModule(network(FLAGS.arch), nclass=nclass,
+ mnist=FLAGS.dataset == 'mnist',
+ epochs=FLAGS.epochs,
+ expid=FLAGS.expid,
+ num_experiments=FLAGS.num_experiments,
+ pkeep=FLAGS.pkeep,
+ save_steps=FLAGS.save_steps,
+ **args
+ )
+
+ r = {}
+ r.update(tm.params)
+
+ open(os.path.join(logdir, 'hparams.json'), "w").write(json.dumps(tm.params))
+ np.save(os.path.join(logdir,'keep.npy'), keep)
+
+ tm.train(FLAGS.epochs, len(xs), train, test, logdir,
+ save_steps=FLAGS.save_steps)
+
+
+if __name__ == '__main__':
+ flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
+ flags.DEFINE_float('lr', 0.1, 'Learning rate.')
+ flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
+ flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
+ flags.DEFINE_integer('batch', 256, 'Batch size')
+ flags.DEFINE_integer('epochs', 100, 'Training duration in number of epochs.')
+ flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
+ flags.DEFINE_integer('seed', None, 'Training seed.')
+ flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
+ flags.DEFINE_integer('expid', None, 'Experiment ID')
+ flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
+ flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
+ flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
+ flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
+
+ flags.DEFINE_integer('num_poison_targets', 250, 'Number of points to target '
+ 'with the poisoning attack.')
+ flags.DEFINE_integer('poison_reps', 8, 'Number of times to repeat each poison.')
+ flags.DEFINE_integer('poison_pos_seed', 0, '')
+ app.run(main)