COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/privacy/pull/234 from ftramer:truth_serum fe44a0713952ef1615abf032947082eb5c082836

PiperOrigin-RevId: 447573314
This commit is contained in:
A. Unique TensorFlower 2022-05-09 15:04:33 -07:00 committed by Steve Chien
parent 137f795352
commit 97eec1a8e3
11 changed files with 581 additions and 70 deletions

View file

@ -2,10 +2,9 @@
This directory contains code to reproduce our paper:
**"Membership Inference Attacks From First Principles"**
https://arxiv.org/abs/2112.03570
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
**"Membership Inference Attacks From First Principles"** <br>
https://arxiv.org/abs/2112.03570 <br>
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr.
### INSTALLING
@ -18,21 +17,20 @@ with JAX + ObJAX so you will need to follow build instructions for that
https://github.com/google/objax
https://objax.readthedocs.io/en/latest/installation_setup.html
### RUNNING THE CODE
#### 1. Train the models
The first step in our attack is to train shadow models. As a baseline
that should give most of the gains in our attack, you should start by
training 16 shadow models with the command
The first step in our attack is to train shadow models. As a baseline that
should give most of the gains in our attack, you should start by training 16
shadow models with the command
> bash scripts/train_demo.sh
or if you have multiple GPUs on your machine and want to train these models
in parallel, then modify and run
or if you have multiple GPUs on your machine and want to train these models in
parallel, then modify and run
> bash scripts/train_demo_multigpu.sh
> bash scripts/train_demo_multigpu.sh
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
will output a bunch of files under the directory exp/cifar10 with structure:
@ -63,14 +61,13 @@ exp/cifar10/
--- 0000000100.npy
```
where this new file has shape (50000, 10) and stores the model's
output features for each example.
where this new file has shape (50000, 10) and stores the model's output features
for each example.
#### 3. Compute membership inference scores
Finally we take the output features and generate our logit-scaled membership inference
scores for each example for each model.
Finally we take the output features and generate our logit-scaled membership
inference scores for each example for each model.
> python3 score.py exp/cifar10/
@ -85,7 +82,6 @@ exp/cifar10/
with shape (50000,) storing just our scores.
### PLOTTING THE RESULTS
Finally we can generate pretty pictures, and run the plotting code
@ -94,7 +90,6 @@ Finally we can generate pretty pictures, and run the plotting code
which should give (something like) the following output
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
```
@ -111,9 +106,9 @@ Attack Global threshold
```
where the global threshold attack is the baseline, and our online,
online-with-fixed-variance, offline, and offline-with-fixed-variance
attack variants are the four other curves. Note that because we only
train a few models, the fixed variance variants perform best.
online-with-fixed-variance, offline, and offline-with-fixed-variance attack
variants are the four other curves. Note that because we only train a few
models, the fixed variance variants perform best.
### Citation
@ -126,4 +121,4 @@ You can cite this paper with
journal={arXiv preprint arXiv:2112.03570},
year={2021}
}
```
```

View file

@ -12,32 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Callable
# pylint: skip-file
# pyformat: disable
import json
import os
import re
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import pickle
from functools import partial
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet
import tensorflow as tf # For data augmentation.
from absl import app
from absl import flags
from dataset import DataSet
from train import MemModule
from train import network
from train import MemModule, network
from collections import defaultdict
FLAGS = flags.FLAGS
@ -56,7 +46,7 @@ def main(argv):
lr=.1,
batch=0,
epochs=0,
weight_decay=0)
weight_decay=0)
def cache_load(arch):
thing = []
@ -68,8 +58,8 @@ def main(argv):
xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
outs = []
@ -90,7 +80,7 @@ def main(argv):
def features(model, xbatch, ybatch):
return get_loss(model, xbatch, ybatch,
shift=0, reflect=True, stride=1)
for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
if re.search(FLAGS.regex, path) is None:
print("Skipping from regex")
@ -99,9 +89,9 @@ def main(argv):
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
arch = hparams['arch']
model = cache_load(arch)()
logdir = os.path.join(FLAGS.logdir, path)
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
max_epoch, last_ckpt = checkpoint.restore(model.vars())
if max_epoch == 0: continue
@ -112,12 +102,12 @@ def main(argv):
first = FLAGS.from_epoch
else:
first = max_epoch-1
for epoch in range(first,max_epoch+1):
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
# no checkpoint saved here
continue
if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
print("Skipping already generated file", epoch)
continue
@ -127,7 +117,7 @@ def main(argv):
except:
print("Fail to load", epoch)
continue
stats = []
for i in range(0,len(xs_all),N):
@ -142,9 +132,7 @@ if __name__ == '__main__':
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
flags.DEFINE_bool('random_labels', False, 'use random labels.')
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
flags.DEFINE_integer('modulus', 8, 'modulus.')
app.run(main)

View file

@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pyformat: disable
import os
import scipy.stats
@ -113,7 +116,7 @@ def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000
dat_out.append(scores[~keep[:, j], j, :])
out_size = min(min(map(len,dat_out)), out_size)
dat_out = np.array([x[:out_size] for x in dat_out])
mean_out = np.median(dat_out, 1)
@ -160,7 +163,7 @@ def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
low = tpr[np.where(fpr<.001)[0][-1]]
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
metric_text = ''
@ -206,7 +209,7 @@ def fig_fpr_tpr():
"Global threshold\n",
metric='auc'
)
plt.semilogx()
plt.semilogy()
plt.xlim(1e-5,1)
@ -220,5 +223,6 @@ def fig_fpr_tpr():
plt.show()
load_data("exp/cifar10/")
fig_fpr_tpr()
if __name__ == '__main__':
load_data("exp/cifar10/")
fig_fpr_tpr()

View file

@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pyformat: disable
import functools
import os
import shutil
@ -24,12 +27,11 @@ import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet, dnnet
from objax.zoo import convnet, wide_resnet
from dataset import DataSet
@ -202,11 +204,11 @@ def get_data(seed):
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']
inputs = (inputs/127.5)-1
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
nclass = np.max(labels)+1
np.random.seed(seed)
@ -233,7 +235,7 @@ def get_data(seed):
aug = lambda x: augment(x, 0, mirror=False)
else:
raise
train = DataSet.from_arrays(xs, ys,
augment_fn=aug)
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
@ -252,7 +254,7 @@ def main(argv):
import time
seed = np.random.randint(0, 1000000000)
seed ^= int(time.time())
args = EasyDict(arch=FLAGS.arch,
lr=FLAGS.lr,
batch=FLAGS.batch,
@ -260,7 +262,7 @@ def main(argv):
augment=FLAGS.augment,
seed=seed)
if FLAGS.tunename:
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
elif FLAGS.expid is not None:
@ -269,7 +271,7 @@ def main(argv):
logdir = "experiment-"+str(seed)
logdir = os.path.join(FLAGS.logdir, logdir)
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)):
print(f"run {FLAGS.expid} already completed.")
return
else:
@ -282,7 +284,7 @@ def main(argv):
os.makedirs(logdir)
train, test, xs, ys, keep, nclass = get_data(seed)
# Define the network and train_it
tm = MemModule(network(FLAGS.arch), nclass=nclass,
mnist=FLAGS.dataset == 'mnist',
@ -303,8 +305,8 @@ def main(argv):
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
save_steps=FLAGS.save_steps, patience=FLAGS.patience)
if __name__ == '__main__':
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
@ -327,3 +329,4 @@ if __name__ == '__main__':
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
flags.DEFINE_bool('tunename', False, 'Use tune name?')
app.run(main)

View file

@ -0,0 +1,116 @@
## Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets
This directory contains code to reproduce results from the paper:
**"Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets"**<br>
https://arxiv.org/abs/2204.00032 <br>
by Florian Tramèr, Reza Shokri, Ayrton San Joaquin, Hoang Le, Matthew Jagielski, Sanghyun Hong and Nicholas Carlini
### INSTALLING
The experiments in this directory are built on top of the
[LiRA membership inference attack](../mi_lira_2021).
After following the [installation instructions](../mi_lira_2021#installing) for
LiRa, make sure the attack code is on your `PYTHONPATH`:
```bash
export PYTHONPATH="${PYTHONPATH}:../mi_lira_2021"
```
### RUNNING THE CODE
#### 1. Train the models
The first step in our attack is to train shadow models, with some data points
targeted by a poisoning attack. You can train 16 shadow models with the command
> bash scripts/train_demo.sh
or if you have multiple GPUs on your machine and want to train these models in
parallel, then modify and run
> bash scripts/train_demo_multigpu.sh
This will train 16 CIFAR-10 wide ResNet models to ~91% accuracy each, with 250
points targeted for poisoning. For each of these 250 targeted points, the
attacker adds 8 mislabeled poisoned copies of the point into the training set.
The training run will output a bunch of files under the directory exp/cifar10
with structure:
```
exp/cifar10/
- xtrain.npy
- ytain.npy
- poison_pos.npy
- experiment_N_of_16
-- hparams.json
-- keep.npy
-- ckpt/
--- 0000000100.npz
-- tb/
```
The following flags control the poisoning attack:
- `num_poison_targets (default=250)`. The number of targeted points.
- `poison_reps (default=8)`. The number of replicas per poison.
- `poison_pos_seed (default=0)`. The random seed to use to choose the target
points.
We recommend that `num_poison_targets * poison_reps < 5000` on CIFAR-10, as
otherwise the poisons introduce too much label noise and the model's accuracy
(and the attack's success rate) will be degraded.
#### 2. Perform inference and compute scores
Exactly as for LiRA, we then evaluate the models on the entire CIFAR-10 dataset,
and generate logit-scaled membership inference scores. See
[here](../mi_lira_2021#2-perform-inference) and
[here](../mi_lira_2021#3-compute-membership-inference-scores) for details.
```bash
python3 -m inference --logdir=exp/cifar10/
python3 -m score exp/cifar10/
```
### PLOTTING THE RESULTS
Finally we can generate pretty pictures, and run the plotting code
```bash
python3 plot_poison.py
```
which should give (something like) the following output
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
```
Attack No poison (LiRA)
AUC 0.7025, Accuracy 0.6258, TPR@0.1%FPR of 0.0544
Attack No poison (Global threshold)
AUC 0.6191, Accuracy 0.6173, TPR@0.1%FPR of 0.0012
Attack With poison (LiRA)
AUC 0.9943, Accuracy 0.9653, TPR@0.1%FPR of 0.4945
Attack With poison (Global threshold)
AUC 0.9922, Accuracy 0.9603, TPR@0.1%FPR of 0.3930
```
where the baselines are LiRA and a simple global threshold on the membership
scores, both without poisoning. With poisoning, both LiRA and the global
threshold attack are boosted significantly. Note that because we only train a
few models, we use the fixed variance variant of LiRA.
### Citation
You can cite this paper with
```
@article{tramer2022truth,
title={Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets},
author={Tramer, Florian and Shokri, Reza and San Joaquin, Ayrton and Le, Hoang and Jagielski, Matthew and Hong, Sanghyun and Carlini, Nicholas},
journal={arXiv preprint arXiv:2204.00032},
year={2022}
}
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

View file

@ -0,0 +1,13 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,116 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pyformat: disable
import os
import numpy as np
import matplotlib.pyplot as plt
import functools
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# from mi_lira_2021
from plot import sweep, load_data, generate_ours, generate_global
def do_plot_all(fn, keep, scores, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
"""
Generate the ROC curves by using one model as test model and the rest to train,
with a full leave-one-out cross-validation.
"""
all_predictions = []
all_answers = []
for i in range(0, len(keep)):
mask = np.zeros(len(keep), dtype=bool)
mask[i:i+1] = True
prediction, answers = fn(keep[~mask],
scores[~mask],
keep[mask],
scores[mask])
all_predictions.extend(prediction)
all_answers.extend(answers)
fpr, tpr, auc, acc = sweep_fn(np.array(all_predictions),
np.array(all_answers, dtype=bool))
low = tpr[np.where(fpr < .001)[0][-1]]
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc, acc, low))
metric_text = ''
if metric == 'auc':
metric_text = 'auc=%.3f' % auc
elif metric == 'acc':
metric_text = 'acc=%.3f' % acc
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
return acc, auc
def fig_fpr_tpr(poison_mask, scores, keep):
plt.figure(figsize=(4, 3))
# evaluate LiRA on the points that were not targeted by poisoning
do_plot_all(functools.partial(generate_ours, fix_variance=True),
keep[:, ~poison_mask], scores[:, ~poison_mask],
"No poison (LiRA)\n",
metric='auc',
)
# evaluate the global-threshold attack on the points that were not targeted by poisoning
do_plot_all(generate_global,
keep[:, ~poison_mask], scores[:, ~poison_mask],
"No poison (Global threshold)\n",
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
)
# evaluate LiRA on the points that were targeted by poisoning
do_plot_all(functools.partial(generate_ours, fix_variance=True),
keep[:, poison_mask], scores[:, poison_mask],
"With poison (LiRA)\n",
metric='auc',
)
# evaluate the global-threshold attack on the points that were targeted by poisoning
do_plot_all(generate_global,
keep[:, poison_mask], scores[:, poison_mask],
"With poison (Global threshold)\n",
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
)
plt.semilogx()
plt.semilogy()
plt.xlim(1e-3, 1)
plt.ylim(1e-3, 1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.plot([0, 1], [0, 1], ls='--', color='gray')
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
plt.legend(fontsize=8)
plt.savefig("/tmp/fprtpr.png")
plt.show()
if __name__ == '__main__':
logdir = "exp/cifar10/"
scores, keep = load_data(logdir)
poison_pos = np.load(os.path.join(logdir, "poison_pos.npy"))
poison_mask = np.zeros(scores.shape[1], dtype=bool)
poison_mask[poison_pos] = True
fig_fpr_tpr(poison_mask, scores, keep)

View file

@ -0,0 +1,30 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15

View file

@ -0,0 +1,32 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
wait;
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
wait;

View file

@ -0,0 +1,214 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pyformat: disable
import os
import shutil
import json
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from objax.util import EasyDict
# from mi_lira_2021
from dataset import DataSet
from train import augment, MemModule, network
FLAGS = flags.FLAGS
def get_data(seed):
"""
This is the function to generate subsets of the data for training models.
First, we get the training dataset either from the numpy cache
or otherwise we load it from tensorflow datasets.
Then, we compute the subset. This works in one of two ways.
1. If we have a seed, then we just randomly choose examples based on
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
2. Otherwise, if we have an experiment ID, then we do something fancier.
If we run each experiment independently then even after a lot of trials
there will still probably be some examples that were always included
or always excluded. So instead, with experiment IDs, we guarantee that
after FLAGS.num_experiments are done, each example is seen exactly half
of the time in train, and half of the time not in train.
Finally, we add some poisons. The same poisoned samples are added for
each randomly generated training set.
We first select FLAGS.num_poison_targets victim points that will be targeted
by the poisoning attack. For each of these victim points, the attacker will
insert FLAGS.poison_reps mislabeled replicas of the point into the training
set.
For CIFAR-10, we recommend that:
`FLAGS.num_poison_targets * FLAGS.poison_reps < 5000`
Otherwise, the poisons might introduce too much label noise and the model's
accuracy (and the attack's success rate) will be degraded.
"""
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
else:
print("First time, creating dataset")
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']
inputs = (inputs/127.5)-1
np.save(os.path.join(FLAGS.logdir, "x_train.npy"), inputs)
np.save(os.path.join(FLAGS.logdir, "y_train.npy"), labels)
nclass = np.max(labels)+1
np.random.seed(seed)
if FLAGS.num_experiments is not None:
np.random.seed(0)
keep = np.random.uniform(0, 1, size=(FLAGS.num_experiments, len(inputs)))
order = keep.argsort(0)
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
keep = np.array(keep[FLAGS.expid], dtype=bool)
else:
keep = np.random.uniform(0, 1, size=len(inputs)) <= FLAGS.pkeep
xs = inputs[keep]
ys = labels[keep]
if FLAGS.num_poison_targets > 0:
# select some points as targets
np.random.seed(FLAGS.poison_pos_seed)
poison_pos = np.random.choice(len(inputs), size=FLAGS.num_poison_targets, replace=False)
# create mislabeled poisons for the targeted points and replicate each
# poison `FLAGS.poison_reps` times
y_noise = np.mod(labels[poison_pos] + np.random.randint(low=1, high=nclass, size=FLAGS.num_poison_targets), nclass)
ypoison = np.repeat(y_noise, FLAGS.poison_reps)
xpoison = np.repeat(inputs[poison_pos], FLAGS.poison_reps, axis=0)
xs = np.concatenate((xs, xpoison), axis=0)
ys = np.concatenate((ys, ypoison), axis=0)
if not os.path.exists(os.path.join(FLAGS.logdir, "poison_pos.npy")):
np.save(os.path.join(FLAGS.logdir, "poison_pos.npy"), poison_pos)
if FLAGS.augment == 'weak':
aug = lambda x: augment(x, 4)
elif FLAGS.augment == 'mirror':
aug = lambda x: augment(x, 0)
elif FLAGS.augment == 'none':
aug = lambda x: augment(x, 0, mirror=False)
else:
raise
print(xs.shape, ys.shape)
train = DataSet.from_arrays(xs, ys,
augment_fn=aug)
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
train = train.cache().shuffle(len(xs)).repeat().parse().augment().batch(FLAGS.batch)
train = train.nchw().one_hot(nclass).prefetch(FLAGS.batch)
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(FLAGS.batch)
return train, test, xs, ys, keep, nclass
def main(argv):
del argv
tf.config.experimental.set_visible_devices([], "GPU")
seed = FLAGS.seed
if seed is None:
import time
seed = np.random.randint(0, 1000000000)
seed ^= int(time.time())
args = EasyDict(arch=FLAGS.arch,
lr=FLAGS.lr,
batch=FLAGS.batch,
weight_decay=FLAGS.weight_decay,
augment=FLAGS.augment,
seed=seed)
if FLAGS.expid is not None:
logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
else:
logdir = "experiment-"+str(seed)
logdir = os.path.join(FLAGS.logdir, logdir)
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % FLAGS.epochs)):
print(f"run {FLAGS.expid} already completed.")
return
else:
if os.path.exists(logdir):
print(f"deleting run {FLAGS.expid} that did not complete.")
shutil.rmtree(logdir)
print(f"starting run {FLAGS.expid}.")
if not os.path.exists(logdir):
os.makedirs(logdir)
train, test, xs, ys, keep, nclass = get_data(seed)
# Define the network and train_it
tm = MemModule(network(FLAGS.arch), nclass=nclass,
mnist=FLAGS.dataset == 'mnist',
epochs=FLAGS.epochs,
expid=FLAGS.expid,
num_experiments=FLAGS.num_experiments,
pkeep=FLAGS.pkeep,
save_steps=FLAGS.save_steps,
**args
)
r = {}
r.update(tm.params)
open(os.path.join(logdir, 'hparams.json'), "w").write(json.dumps(tm.params))
np.save(os.path.join(logdir,'keep.npy'), keep)
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
save_steps=FLAGS.save_steps)
if __name__ == '__main__':
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
flags.DEFINE_integer('batch', 256, 'Batch size')
flags.DEFINE_integer('epochs', 100, 'Training duration in number of epochs.')
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_integer('seed', None, 'Training seed.')
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
flags.DEFINE_integer('expid', None, 'Experiment ID')
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
flags.DEFINE_integer('num_poison_targets', 250, 'Number of points to target '
'with the poisoning attack.')
flags.DEFINE_integer('poison_reps', 8, 'Number of times to repeat each poison.')
flags.DEFINE_integer('poison_pos_seed', 0, '')
app.run(main)