forked from 626_privacy/tensorflow_privacy
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/privacy/pull/234 from ftramer:truth_serum fe44a0713952ef1615abf032947082eb5c082836
PiperOrigin-RevId: 447573314
This commit is contained in:
parent
137f795352
commit
97eec1a8e3
11 changed files with 581 additions and 70 deletions
|
@ -2,10 +2,9 @@
|
|||
|
||||
This directory contains code to reproduce our paper:
|
||||
|
||||
**"Membership Inference Attacks From First Principles"**
|
||||
https://arxiv.org/abs/2112.03570
|
||||
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
|
||||
|
||||
**"Membership Inference Attacks From First Principles"** <br>
|
||||
https://arxiv.org/abs/2112.03570 <br>
|
||||
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr.
|
||||
|
||||
### INSTALLING
|
||||
|
||||
|
@ -18,19 +17,18 @@ with JAX + ObJAX so you will need to follow build instructions for that
|
|||
https://github.com/google/objax
|
||||
https://objax.readthedocs.io/en/latest/installation_setup.html
|
||||
|
||||
|
||||
### RUNNING THE CODE
|
||||
|
||||
#### 1. Train the models
|
||||
|
||||
The first step in our attack is to train shadow models. As a baseline
|
||||
that should give most of the gains in our attack, you should start by
|
||||
training 16 shadow models with the command
|
||||
The first step in our attack is to train shadow models. As a baseline that
|
||||
should give most of the gains in our attack, you should start by training 16
|
||||
shadow models with the command
|
||||
|
||||
> bash scripts/train_demo.sh
|
||||
|
||||
or if you have multiple GPUs on your machine and want to train these models
|
||||
in parallel, then modify and run
|
||||
or if you have multiple GPUs on your machine and want to train these models in
|
||||
parallel, then modify and run
|
||||
|
||||
> bash scripts/train_demo_multigpu.sh
|
||||
|
||||
|
@ -63,14 +61,13 @@ exp/cifar10/
|
|||
--- 0000000100.npy
|
||||
```
|
||||
|
||||
where this new file has shape (50000, 10) and stores the model's
|
||||
output features for each example.
|
||||
|
||||
where this new file has shape (50000, 10) and stores the model's output features
|
||||
for each example.
|
||||
|
||||
#### 3. Compute membership inference scores
|
||||
|
||||
Finally we take the output features and generate our logit-scaled membership inference
|
||||
scores for each example for each model.
|
||||
Finally we take the output features and generate our logit-scaled membership
|
||||
inference scores for each example for each model.
|
||||
|
||||
> python3 score.py exp/cifar10/
|
||||
|
||||
|
@ -85,7 +82,6 @@ exp/cifar10/
|
|||
|
||||
with shape (50000,) storing just our scores.
|
||||
|
||||
|
||||
### PLOTTING THE RESULTS
|
||||
|
||||
Finally we can generate pretty pictures, and run the plotting code
|
||||
|
@ -94,7 +90,6 @@ Finally we can generate pretty pictures, and run the plotting code
|
|||
|
||||
which should give (something like) the following output
|
||||
|
||||
|
||||
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
||||
|
||||
```
|
||||
|
@ -111,9 +106,9 @@ Attack Global threshold
|
|||
```
|
||||
|
||||
where the global threshold attack is the baseline, and our online,
|
||||
online-with-fixed-variance, offline, and offline-with-fixed-variance
|
||||
attack variants are the four other curves. Note that because we only
|
||||
train a few models, the fixed variance variants perform best.
|
||||
online-with-fixed-variance, offline, and offline-with-fixed-variance attack
|
||||
variants are the four other curves. Note that because we only train a few
|
||||
models, the fixed variance variants perform best.
|
||||
|
||||
### Citation
|
||||
|
||||
|
|
|
@ -12,32 +12,22 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable
|
||||
# pylint: skip-file
|
||||
# pyformat: disable
|
||||
|
||||
import json
|
||||
|
||||
import os
|
||||
import re
|
||||
import jax
|
||||
import jax.numpy as jn
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # For data augmentation.
|
||||
import tensorflow_datasets as tfds
|
||||
from absl import app, flags
|
||||
from tqdm import tqdm, trange
|
||||
import pickle
|
||||
from functools import partial
|
||||
|
||||
import objax
|
||||
from objax.jaxboard import SummaryWriter, Summary
|
||||
from objax.util import EasyDict
|
||||
from objax.zoo import convnet, wide_resnet
|
||||
import tensorflow as tf # For data augmentation.
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from dataset import DataSet
|
||||
from train import MemModule
|
||||
from train import network
|
||||
|
||||
from train import MemModule, network
|
||||
|
||||
from collections import defaultdict
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
|
@ -142,9 +132,7 @@ if __name__ == '__main__':
|
|||
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
|
||||
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
|
||||
flags.DEFINE_bool('random_labels', False, 'use random labels.')
|
||||
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
|
||||
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
|
||||
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
|
||||
flags.DEFINE_integer('modulus', 8, 'modulus.')
|
||||
app.run(main)
|
||||
|
||||
|
|
|
@ -12,6 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
# pyformat: disable
|
||||
|
||||
import os
|
||||
import scipy.stats
|
||||
|
||||
|
@ -220,5 +223,6 @@ def fig_fpr_tpr():
|
|||
plt.show()
|
||||
|
||||
|
||||
load_data("exp/cifar10/")
|
||||
fig_fpr_tpr()
|
||||
if __name__ == '__main__':
|
||||
load_data("exp/cifar10/")
|
||||
fig_fpr_tpr()
|
||||
|
|
|
@ -12,6 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
# pyformat: disable
|
||||
|
||||
import functools
|
||||
import os
|
||||
import shutil
|
||||
|
@ -24,12 +27,11 @@ import numpy as np
|
|||
import tensorflow as tf # For data augmentation.
|
||||
import tensorflow_datasets as tfds
|
||||
from absl import app, flags
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import objax
|
||||
from objax.jaxboard import SummaryWriter, Summary
|
||||
from objax.util import EasyDict
|
||||
from objax.zoo import convnet, wide_resnet, dnnet
|
||||
from objax.zoo import convnet, wide_resnet
|
||||
|
||||
from dataset import DataSet
|
||||
|
||||
|
@ -269,7 +271,7 @@ def main(argv):
|
|||
logdir = "experiment-"+str(seed)
|
||||
logdir = os.path.join(FLAGS.logdir, logdir)
|
||||
|
||||
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
|
||||
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)):
|
||||
print(f"run {FLAGS.expid} already completed.")
|
||||
return
|
||||
else:
|
||||
|
@ -327,3 +329,4 @@ if __name__ == '__main__':
|
|||
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
|
||||
flags.DEFINE_bool('tunename', False, 'Use tune name?')
|
||||
app.run(main)
|
||||
|
||||
|
|
116
research/mi_poison_2022/README.md
Normal file
116
research/mi_poison_2022/README.md
Normal file
|
@ -0,0 +1,116 @@
|
|||
## Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets
|
||||
|
||||
This directory contains code to reproduce results from the paper:
|
||||
|
||||
**"Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets"**<br>
|
||||
https://arxiv.org/abs/2204.00032 <br>
|
||||
by Florian Tramèr, Reza Shokri, Ayrton San Joaquin, Hoang Le, Matthew Jagielski, Sanghyun Hong and Nicholas Carlini
|
||||
|
||||
### INSTALLING
|
||||
|
||||
The experiments in this directory are built on top of the
|
||||
[LiRA membership inference attack](../mi_lira_2021).
|
||||
|
||||
After following the [installation instructions](../mi_lira_2021#installing) for
|
||||
LiRa, make sure the attack code is on your `PYTHONPATH`:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH="${PYTHONPATH}:../mi_lira_2021"
|
||||
```
|
||||
|
||||
### RUNNING THE CODE
|
||||
|
||||
#### 1. Train the models
|
||||
|
||||
The first step in our attack is to train shadow models, with some data points
|
||||
targeted by a poisoning attack. You can train 16 shadow models with the command
|
||||
|
||||
> bash scripts/train_demo.sh
|
||||
|
||||
or if you have multiple GPUs on your machine and want to train these models in
|
||||
parallel, then modify and run
|
||||
|
||||
> bash scripts/train_demo_multigpu.sh
|
||||
|
||||
This will train 16 CIFAR-10 wide ResNet models to ~91% accuracy each, with 250
|
||||
points targeted for poisoning. For each of these 250 targeted points, the
|
||||
attacker adds 8 mislabeled poisoned copies of the point into the training set.
|
||||
The training run will output a bunch of files under the directory exp/cifar10
|
||||
with structure:
|
||||
|
||||
```
|
||||
exp/cifar10/
|
||||
- xtrain.npy
|
||||
- ytain.npy
|
||||
- poison_pos.npy
|
||||
- experiment_N_of_16
|
||||
-- hparams.json
|
||||
-- keep.npy
|
||||
-- ckpt/
|
||||
--- 0000000100.npz
|
||||
-- tb/
|
||||
```
|
||||
|
||||
The following flags control the poisoning attack:
|
||||
|
||||
- `num_poison_targets (default=250)`. The number of targeted points.
|
||||
- `poison_reps (default=8)`. The number of replicas per poison.
|
||||
- `poison_pos_seed (default=0)`. The random seed to use to choose the target
|
||||
points.
|
||||
|
||||
We recommend that `num_poison_targets * poison_reps < 5000` on CIFAR-10, as
|
||||
otherwise the poisons introduce too much label noise and the model's accuracy
|
||||
(and the attack's success rate) will be degraded.
|
||||
|
||||
#### 2. Perform inference and compute scores
|
||||
|
||||
Exactly as for LiRA, we then evaluate the models on the entire CIFAR-10 dataset,
|
||||
and generate logit-scaled membership inference scores. See
|
||||
[here](../mi_lira_2021#2-perform-inference) and
|
||||
[here](../mi_lira_2021#3-compute-membership-inference-scores) for details.
|
||||
|
||||
```bash
|
||||
python3 -m inference --logdir=exp/cifar10/
|
||||
python3 -m score exp/cifar10/
|
||||
```
|
||||
|
||||
### PLOTTING THE RESULTS
|
||||
|
||||
Finally we can generate pretty pictures, and run the plotting code
|
||||
|
||||
```bash
|
||||
python3 plot_poison.py
|
||||
```
|
||||
|
||||
which should give (something like) the following output
|
||||
|
||||
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
||||
|
||||
```
|
||||
Attack No poison (LiRA)
|
||||
AUC 0.7025, Accuracy 0.6258, TPR@0.1%FPR of 0.0544
|
||||
Attack No poison (Global threshold)
|
||||
AUC 0.6191, Accuracy 0.6173, TPR@0.1%FPR of 0.0012
|
||||
Attack With poison (LiRA)
|
||||
AUC 0.9943, Accuracy 0.9653, TPR@0.1%FPR of 0.4945
|
||||
Attack With poison (Global threshold)
|
||||
AUC 0.9922, Accuracy 0.9603, TPR@0.1%FPR of 0.3930
|
||||
```
|
||||
|
||||
where the baselines are LiRA and a simple global threshold on the membership
|
||||
scores, both without poisoning. With poisoning, both LiRA and the global
|
||||
threshold attack are boosted significantly. Note that because we only train a
|
||||
few models, we use the fixed variance variant of LiRA.
|
||||
|
||||
### Citation
|
||||
|
||||
You can cite this paper with
|
||||
|
||||
```
|
||||
@article{tramer2022truth,
|
||||
title={Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets},
|
||||
author={Tramer, Florian and Shokri, Reza and San Joaquin, Ayrton and Le, Hoang and Jagielski, Matthew and Hong, Sanghyun and Carlini, Nicholas},
|
||||
journal={arXiv preprint arXiv:2204.00032},
|
||||
year={2022}
|
||||
}
|
||||
```
|
BIN
research/mi_poison_2022/fprtpr.png
Normal file
BIN
research/mi_poison_2022/fprtpr.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 32 KiB |
13
research/mi_poison_2022/logs/.keep
Normal file
13
research/mi_poison_2022/logs/.keep
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
116
research/mi_poison_2022/plot_poison.py
Normal file
116
research/mi_poison_2022/plot_poison.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
# pyformat: disable
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import functools
|
||||
|
||||
import matplotlib
|
||||
matplotlib.rcParams['pdf.fonttype'] = 42
|
||||
matplotlib.rcParams['ps.fonttype'] = 42
|
||||
|
||||
# from mi_lira_2021
|
||||
from plot import sweep, load_data, generate_ours, generate_global
|
||||
|
||||
|
||||
def do_plot_all(fn, keep, scores, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
|
||||
"""
|
||||
Generate the ROC curves by using one model as test model and the rest to train,
|
||||
with a full leave-one-out cross-validation.
|
||||
"""
|
||||
|
||||
all_predictions = []
|
||||
all_answers = []
|
||||
for i in range(0, len(keep)):
|
||||
mask = np.zeros(len(keep), dtype=bool)
|
||||
mask[i:i+1] = True
|
||||
prediction, answers = fn(keep[~mask],
|
||||
scores[~mask],
|
||||
keep[mask],
|
||||
scores[mask])
|
||||
all_predictions.extend(prediction)
|
||||
all_answers.extend(answers)
|
||||
|
||||
fpr, tpr, auc, acc = sweep_fn(np.array(all_predictions),
|
||||
np.array(all_answers, dtype=bool))
|
||||
|
||||
low = tpr[np.where(fpr < .001)[0][-1]]
|
||||
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc, acc, low))
|
||||
|
||||
metric_text = ''
|
||||
if metric == 'auc':
|
||||
metric_text = 'auc=%.3f' % auc
|
||||
elif metric == 'acc':
|
||||
metric_text = 'acc=%.3f' % acc
|
||||
|
||||
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
|
||||
return acc, auc
|
||||
|
||||
|
||||
def fig_fpr_tpr(poison_mask, scores, keep):
|
||||
|
||||
plt.figure(figsize=(4, 3))
|
||||
|
||||
# evaluate LiRA on the points that were not targeted by poisoning
|
||||
do_plot_all(functools.partial(generate_ours, fix_variance=True),
|
||||
keep[:, ~poison_mask], scores[:, ~poison_mask],
|
||||
"No poison (LiRA)\n",
|
||||
metric='auc',
|
||||
)
|
||||
|
||||
# evaluate the global-threshold attack on the points that were not targeted by poisoning
|
||||
do_plot_all(generate_global,
|
||||
keep[:, ~poison_mask], scores[:, ~poison_mask],
|
||||
"No poison (Global threshold)\n",
|
||||
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
|
||||
)
|
||||
|
||||
# evaluate LiRA on the points that were targeted by poisoning
|
||||
do_plot_all(functools.partial(generate_ours, fix_variance=True),
|
||||
keep[:, poison_mask], scores[:, poison_mask],
|
||||
"With poison (LiRA)\n",
|
||||
metric='auc',
|
||||
)
|
||||
|
||||
# evaluate the global-threshold attack on the points that were targeted by poisoning
|
||||
do_plot_all(generate_global,
|
||||
keep[:, poison_mask], scores[:, poison_mask],
|
||||
"With poison (Global threshold)\n",
|
||||
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
|
||||
)
|
||||
|
||||
plt.semilogx()
|
||||
plt.semilogy()
|
||||
plt.xlim(1e-3, 1)
|
||||
plt.ylim(1e-3, 1)
|
||||
plt.xlabel("False Positive Rate")
|
||||
plt.ylabel("True Positive Rate")
|
||||
plt.plot([0, 1], [0, 1], ls='--', color='gray')
|
||||
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
|
||||
plt.legend(fontsize=8)
|
||||
plt.savefig("/tmp/fprtpr.png")
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logdir = "exp/cifar10/"
|
||||
scores, keep = load_data(logdir)
|
||||
poison_pos = np.load(os.path.join(logdir, "poison_pos.npy"))
|
||||
poison_mask = np.zeros(scores.shape[1], dtype=bool)
|
||||
poison_mask[poison_pos] = True
|
||||
fig_fpr_tpr(poison_mask, scores, keep)
|
30
research/mi_poison_2022/scripts/train_demo.sh
Normal file
30
research/mi_poison_2022/scripts/train_demo.sh
Normal file
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
|
32
research/mi_poison_2022/scripts/train_demo_multigpu.sh
Normal file
32
research/mi_poison_2022/scripts/train_demo_multigpu.sh
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||
wait;
|
214
research/mi_poison_2022/train_poison.py
Normal file
214
research/mi_poison_2022/train_poison.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
# pyformat: disable
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # For data augmentation.
|
||||
import tensorflow_datasets as tfds
|
||||
from absl import app, flags
|
||||
|
||||
from objax.util import EasyDict
|
||||
|
||||
# from mi_lira_2021
|
||||
from dataset import DataSet
|
||||
from train import augment, MemModule, network
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def get_data(seed):
|
||||
"""
|
||||
This is the function to generate subsets of the data for training models.
|
||||
|
||||
First, we get the training dataset either from the numpy cache
|
||||
or otherwise we load it from tensorflow datasets.
|
||||
|
||||
Then, we compute the subset. This works in one of two ways.
|
||||
|
||||
1. If we have a seed, then we just randomly choose examples based on
|
||||
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
|
||||
|
||||
2. Otherwise, if we have an experiment ID, then we do something fancier.
|
||||
If we run each experiment independently then even after a lot of trials
|
||||
there will still probably be some examples that were always included
|
||||
or always excluded. So instead, with experiment IDs, we guarantee that
|
||||
after FLAGS.num_experiments are done, each example is seen exactly half
|
||||
of the time in train, and half of the time not in train.
|
||||
|
||||
Finally, we add some poisons. The same poisoned samples are added for
|
||||
each randomly generated training set.
|
||||
We first select FLAGS.num_poison_targets victim points that will be targeted
|
||||
by the poisoning attack. For each of these victim points, the attacker will
|
||||
insert FLAGS.poison_reps mislabeled replicas of the point into the training
|
||||
set.
|
||||
|
||||
For CIFAR-10, we recommend that:
|
||||
|
||||
`FLAGS.num_poison_targets * FLAGS.poison_reps < 5000`
|
||||
|
||||
Otherwise, the poisons might introduce too much label noise and the model's
|
||||
accuracy (and the attack's success rate) will be degraded.
|
||||
"""
|
||||
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
|
||||
|
||||
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
|
||||
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
|
||||
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
|
||||
else:
|
||||
print("First time, creating dataset")
|
||||
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
|
||||
inputs = data['train']['image']
|
||||
labels = data['train']['label']
|
||||
|
||||
inputs = (inputs/127.5)-1
|
||||
np.save(os.path.join(FLAGS.logdir, "x_train.npy"), inputs)
|
||||
np.save(os.path.join(FLAGS.logdir, "y_train.npy"), labels)
|
||||
|
||||
nclass = np.max(labels)+1
|
||||
|
||||
np.random.seed(seed)
|
||||
if FLAGS.num_experiments is not None:
|
||||
np.random.seed(0)
|
||||
keep = np.random.uniform(0, 1, size=(FLAGS.num_experiments, len(inputs)))
|
||||
order = keep.argsort(0)
|
||||
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
|
||||
keep = np.array(keep[FLAGS.expid], dtype=bool)
|
||||
else:
|
||||
keep = np.random.uniform(0, 1, size=len(inputs)) <= FLAGS.pkeep
|
||||
|
||||
xs = inputs[keep]
|
||||
ys = labels[keep]
|
||||
|
||||
if FLAGS.num_poison_targets > 0:
|
||||
|
||||
# select some points as targets
|
||||
np.random.seed(FLAGS.poison_pos_seed)
|
||||
poison_pos = np.random.choice(len(inputs), size=FLAGS.num_poison_targets, replace=False)
|
||||
|
||||
# create mislabeled poisons for the targeted points and replicate each
|
||||
# poison `FLAGS.poison_reps` times
|
||||
y_noise = np.mod(labels[poison_pos] + np.random.randint(low=1, high=nclass, size=FLAGS.num_poison_targets), nclass)
|
||||
ypoison = np.repeat(y_noise, FLAGS.poison_reps)
|
||||
xpoison = np.repeat(inputs[poison_pos], FLAGS.poison_reps, axis=0)
|
||||
xs = np.concatenate((xs, xpoison), axis=0)
|
||||
ys = np.concatenate((ys, ypoison), axis=0)
|
||||
|
||||
if not os.path.exists(os.path.join(FLAGS.logdir, "poison_pos.npy")):
|
||||
np.save(os.path.join(FLAGS.logdir, "poison_pos.npy"), poison_pos)
|
||||
|
||||
if FLAGS.augment == 'weak':
|
||||
aug = lambda x: augment(x, 4)
|
||||
elif FLAGS.augment == 'mirror':
|
||||
aug = lambda x: augment(x, 0)
|
||||
elif FLAGS.augment == 'none':
|
||||
aug = lambda x: augment(x, 0, mirror=False)
|
||||
else:
|
||||
raise
|
||||
|
||||
print(xs.shape, ys.shape)
|
||||
train = DataSet.from_arrays(xs, ys,
|
||||
augment_fn=aug)
|
||||
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
|
||||
train = train.cache().shuffle(len(xs)).repeat().parse().augment().batch(FLAGS.batch)
|
||||
train = train.nchw().one_hot(nclass).prefetch(FLAGS.batch)
|
||||
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(FLAGS.batch)
|
||||
|
||||
return train, test, xs, ys, keep, nclass
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
tf.config.experimental.set_visible_devices([], "GPU")
|
||||
|
||||
seed = FLAGS.seed
|
||||
if seed is None:
|
||||
import time
|
||||
seed = np.random.randint(0, 1000000000)
|
||||
seed ^= int(time.time())
|
||||
|
||||
args = EasyDict(arch=FLAGS.arch,
|
||||
lr=FLAGS.lr,
|
||||
batch=FLAGS.batch,
|
||||
weight_decay=FLAGS.weight_decay,
|
||||
augment=FLAGS.augment,
|
||||
seed=seed)
|
||||
|
||||
if FLAGS.expid is not None:
|
||||
logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
|
||||
else:
|
||||
logdir = "experiment-"+str(seed)
|
||||
logdir = os.path.join(FLAGS.logdir, logdir)
|
||||
|
||||
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % FLAGS.epochs)):
|
||||
print(f"run {FLAGS.expid} already completed.")
|
||||
return
|
||||
else:
|
||||
if os.path.exists(logdir):
|
||||
print(f"deleting run {FLAGS.expid} that did not complete.")
|
||||
shutil.rmtree(logdir)
|
||||
|
||||
print(f"starting run {FLAGS.expid}.")
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
|
||||
train, test, xs, ys, keep, nclass = get_data(seed)
|
||||
|
||||
# Define the network and train_it
|
||||
tm = MemModule(network(FLAGS.arch), nclass=nclass,
|
||||
mnist=FLAGS.dataset == 'mnist',
|
||||
epochs=FLAGS.epochs,
|
||||
expid=FLAGS.expid,
|
||||
num_experiments=FLAGS.num_experiments,
|
||||
pkeep=FLAGS.pkeep,
|
||||
save_steps=FLAGS.save_steps,
|
||||
**args
|
||||
)
|
||||
|
||||
r = {}
|
||||
r.update(tm.params)
|
||||
|
||||
open(os.path.join(logdir, 'hparams.json'), "w").write(json.dumps(tm.params))
|
||||
np.save(os.path.join(logdir,'keep.npy'), keep)
|
||||
|
||||
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
|
||||
save_steps=FLAGS.save_steps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
|
||||
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
|
||||
flags.DEFINE_integer('batch', 256, 'Batch size')
|
||||
flags.DEFINE_integer('epochs', 100, 'Training duration in number of epochs.')
|
||||
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
|
||||
flags.DEFINE_integer('seed', None, 'Training seed.')
|
||||
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
|
||||
flags.DEFINE_integer('expid', None, 'Experiment ID')
|
||||
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
|
||||
flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
|
||||
flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
|
||||
flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
|
||||
|
||||
flags.DEFINE_integer('num_poison_targets', 250, 'Number of points to target '
|
||||
'with the poisoning attack.')
|
||||
flags.DEFINE_integer('poison_reps', 8, 'Number of times to repeat each poison.')
|
||||
flags.DEFINE_integer('poison_pos_seed', 0, '')
|
||||
app.run(main)
|
Loading…
Reference in a new issue