COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/privacy/pull/234 from ftramer:truth_serum fe44a0713952ef1615abf032947082eb5c082836
PiperOrigin-RevId: 447573314
This commit is contained in:
parent
137f795352
commit
97eec1a8e3
11 changed files with 581 additions and 70 deletions
|
@ -2,10 +2,9 @@
|
||||||
|
|
||||||
This directory contains code to reproduce our paper:
|
This directory contains code to reproduce our paper:
|
||||||
|
|
||||||
**"Membership Inference Attacks From First Principles"**
|
**"Membership Inference Attacks From First Principles"** <br>
|
||||||
https://arxiv.org/abs/2112.03570
|
https://arxiv.org/abs/2112.03570 <br>
|
||||||
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
|
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr.
|
||||||
|
|
||||||
|
|
||||||
### INSTALLING
|
### INSTALLING
|
||||||
|
|
||||||
|
@ -18,19 +17,18 @@ with JAX + ObJAX so you will need to follow build instructions for that
|
||||||
https://github.com/google/objax
|
https://github.com/google/objax
|
||||||
https://objax.readthedocs.io/en/latest/installation_setup.html
|
https://objax.readthedocs.io/en/latest/installation_setup.html
|
||||||
|
|
||||||
|
|
||||||
### RUNNING THE CODE
|
### RUNNING THE CODE
|
||||||
|
|
||||||
#### 1. Train the models
|
#### 1. Train the models
|
||||||
|
|
||||||
The first step in our attack is to train shadow models. As a baseline
|
The first step in our attack is to train shadow models. As a baseline that
|
||||||
that should give most of the gains in our attack, you should start by
|
should give most of the gains in our attack, you should start by training 16
|
||||||
training 16 shadow models with the command
|
shadow models with the command
|
||||||
|
|
||||||
> bash scripts/train_demo.sh
|
> bash scripts/train_demo.sh
|
||||||
|
|
||||||
or if you have multiple GPUs on your machine and want to train these models
|
or if you have multiple GPUs on your machine and want to train these models in
|
||||||
in parallel, then modify and run
|
parallel, then modify and run
|
||||||
|
|
||||||
> bash scripts/train_demo_multigpu.sh
|
> bash scripts/train_demo_multigpu.sh
|
||||||
|
|
||||||
|
@ -63,14 +61,13 @@ exp/cifar10/
|
||||||
--- 0000000100.npy
|
--- 0000000100.npy
|
||||||
```
|
```
|
||||||
|
|
||||||
where this new file has shape (50000, 10) and stores the model's
|
where this new file has shape (50000, 10) and stores the model's output features
|
||||||
output features for each example.
|
for each example.
|
||||||
|
|
||||||
|
|
||||||
#### 3. Compute membership inference scores
|
#### 3. Compute membership inference scores
|
||||||
|
|
||||||
Finally we take the output features and generate our logit-scaled membership inference
|
Finally we take the output features and generate our logit-scaled membership
|
||||||
scores for each example for each model.
|
inference scores for each example for each model.
|
||||||
|
|
||||||
> python3 score.py exp/cifar10/
|
> python3 score.py exp/cifar10/
|
||||||
|
|
||||||
|
@ -85,7 +82,6 @@ exp/cifar10/
|
||||||
|
|
||||||
with shape (50000,) storing just our scores.
|
with shape (50000,) storing just our scores.
|
||||||
|
|
||||||
|
|
||||||
### PLOTTING THE RESULTS
|
### PLOTTING THE RESULTS
|
||||||
|
|
||||||
Finally we can generate pretty pictures, and run the plotting code
|
Finally we can generate pretty pictures, and run the plotting code
|
||||||
|
@ -94,7 +90,6 @@ Finally we can generate pretty pictures, and run the plotting code
|
||||||
|
|
||||||
which should give (something like) the following output
|
which should give (something like) the following output
|
||||||
|
|
||||||
|
|
||||||
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -111,9 +106,9 @@ Attack Global threshold
|
||||||
```
|
```
|
||||||
|
|
||||||
where the global threshold attack is the baseline, and our online,
|
where the global threshold attack is the baseline, and our online,
|
||||||
online-with-fixed-variance, offline, and offline-with-fixed-variance
|
online-with-fixed-variance, offline, and offline-with-fixed-variance attack
|
||||||
attack variants are the four other curves. Note that because we only
|
variants are the four other curves. Note that because we only train a few
|
||||||
train a few models, the fixed variance variants perform best.
|
models, the fixed variance variants perform best.
|
||||||
|
|
||||||
### Citation
|
### Citation
|
||||||
|
|
||||||
|
|
|
@ -12,32 +12,22 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
# pylint: skip-file
|
||||||
import os
|
# pyformat: disable
|
||||||
from typing import Callable
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import jax
|
|
||||||
import jax.numpy as jn
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf # For data augmentation.
|
|
||||||
import tensorflow_datasets as tfds
|
|
||||||
from absl import app, flags
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
import pickle
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import objax
|
import objax
|
||||||
from objax.jaxboard import SummaryWriter, Summary
|
import tensorflow as tf # For data augmentation.
|
||||||
from objax.util import EasyDict
|
from absl import app
|
||||||
from objax.zoo import convnet, wide_resnet
|
from absl import flags
|
||||||
|
|
||||||
from dataset import DataSet
|
from train import MemModule
|
||||||
|
from train import network
|
||||||
|
|
||||||
from train import MemModule, network
|
|
||||||
|
|
||||||
from collections import defaultdict
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,9 +132,7 @@ if __name__ == '__main__':
|
||||||
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||||
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
|
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
|
||||||
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
|
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
|
||||||
flags.DEFINE_bool('random_labels', False, 'use random labels.')
|
|
||||||
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
|
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
|
||||||
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
|
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
|
||||||
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
|
|
||||||
flags.DEFINE_integer('modulus', 8, 'modulus.')
|
|
||||||
app.run(main)
|
app.run(main)
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
# pyformat: disable
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
|
|
||||||
|
@ -220,5 +223,6 @@ def fig_fpr_tpr():
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
load_data("exp/cifar10/")
|
if __name__ == '__main__':
|
||||||
fig_fpr_tpr()
|
load_data("exp/cifar10/")
|
||||||
|
fig_fpr_tpr()
|
||||||
|
|
|
@ -12,6 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
# pyformat: disable
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
@ -24,12 +27,11 @@ import numpy as np
|
||||||
import tensorflow as tf # For data augmentation.
|
import tensorflow as tf # For data augmentation.
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
from absl import app, flags
|
from absl import app, flags
|
||||||
from tqdm import tqdm, trange
|
|
||||||
|
|
||||||
import objax
|
import objax
|
||||||
from objax.jaxboard import SummaryWriter, Summary
|
from objax.jaxboard import SummaryWriter, Summary
|
||||||
from objax.util import EasyDict
|
from objax.util import EasyDict
|
||||||
from objax.zoo import convnet, wide_resnet, dnnet
|
from objax.zoo import convnet, wide_resnet
|
||||||
|
|
||||||
from dataset import DataSet
|
from dataset import DataSet
|
||||||
|
|
||||||
|
@ -269,7 +271,7 @@ def main(argv):
|
||||||
logdir = "experiment-"+str(seed)
|
logdir = "experiment-"+str(seed)
|
||||||
logdir = os.path.join(FLAGS.logdir, logdir)
|
logdir = os.path.join(FLAGS.logdir, logdir)
|
||||||
|
|
||||||
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
|
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)):
|
||||||
print(f"run {FLAGS.expid} already completed.")
|
print(f"run {FLAGS.expid} already completed.")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
@ -327,3 +329,4 @@ if __name__ == '__main__':
|
||||||
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
|
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
|
||||||
flags.DEFINE_bool('tunename', False, 'Use tune name?')
|
flags.DEFINE_bool('tunename', False, 'Use tune name?')
|
||||||
app.run(main)
|
app.run(main)
|
||||||
|
|
||||||
|
|
116
research/mi_poison_2022/README.md
Normal file
116
research/mi_poison_2022/README.md
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
## Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets
|
||||||
|
|
||||||
|
This directory contains code to reproduce results from the paper:
|
||||||
|
|
||||||
|
**"Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets"**<br>
|
||||||
|
https://arxiv.org/abs/2204.00032 <br>
|
||||||
|
by Florian Tramèr, Reza Shokri, Ayrton San Joaquin, Hoang Le, Matthew Jagielski, Sanghyun Hong and Nicholas Carlini
|
||||||
|
|
||||||
|
### INSTALLING
|
||||||
|
|
||||||
|
The experiments in this directory are built on top of the
|
||||||
|
[LiRA membership inference attack](../mi_lira_2021).
|
||||||
|
|
||||||
|
After following the [installation instructions](../mi_lira_2021#installing) for
|
||||||
|
LiRa, make sure the attack code is on your `PYTHONPATH`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export PYTHONPATH="${PYTHONPATH}:../mi_lira_2021"
|
||||||
|
```
|
||||||
|
|
||||||
|
### RUNNING THE CODE
|
||||||
|
|
||||||
|
#### 1. Train the models
|
||||||
|
|
||||||
|
The first step in our attack is to train shadow models, with some data points
|
||||||
|
targeted by a poisoning attack. You can train 16 shadow models with the command
|
||||||
|
|
||||||
|
> bash scripts/train_demo.sh
|
||||||
|
|
||||||
|
or if you have multiple GPUs on your machine and want to train these models in
|
||||||
|
parallel, then modify and run
|
||||||
|
|
||||||
|
> bash scripts/train_demo_multigpu.sh
|
||||||
|
|
||||||
|
This will train 16 CIFAR-10 wide ResNet models to ~91% accuracy each, with 250
|
||||||
|
points targeted for poisoning. For each of these 250 targeted points, the
|
||||||
|
attacker adds 8 mislabeled poisoned copies of the point into the training set.
|
||||||
|
The training run will output a bunch of files under the directory exp/cifar10
|
||||||
|
with structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
exp/cifar10/
|
||||||
|
- xtrain.npy
|
||||||
|
- ytain.npy
|
||||||
|
- poison_pos.npy
|
||||||
|
- experiment_N_of_16
|
||||||
|
-- hparams.json
|
||||||
|
-- keep.npy
|
||||||
|
-- ckpt/
|
||||||
|
--- 0000000100.npz
|
||||||
|
-- tb/
|
||||||
|
```
|
||||||
|
|
||||||
|
The following flags control the poisoning attack:
|
||||||
|
|
||||||
|
- `num_poison_targets (default=250)`. The number of targeted points.
|
||||||
|
- `poison_reps (default=8)`. The number of replicas per poison.
|
||||||
|
- `poison_pos_seed (default=0)`. The random seed to use to choose the target
|
||||||
|
points.
|
||||||
|
|
||||||
|
We recommend that `num_poison_targets * poison_reps < 5000` on CIFAR-10, as
|
||||||
|
otherwise the poisons introduce too much label noise and the model's accuracy
|
||||||
|
(and the attack's success rate) will be degraded.
|
||||||
|
|
||||||
|
#### 2. Perform inference and compute scores
|
||||||
|
|
||||||
|
Exactly as for LiRA, we then evaluate the models on the entire CIFAR-10 dataset,
|
||||||
|
and generate logit-scaled membership inference scores. See
|
||||||
|
[here](../mi_lira_2021#2-perform-inference) and
|
||||||
|
[here](../mi_lira_2021#3-compute-membership-inference-scores) for details.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m inference --logdir=exp/cifar10/
|
||||||
|
python3 -m score exp/cifar10/
|
||||||
|
```
|
||||||
|
|
||||||
|
### PLOTTING THE RESULTS
|
||||||
|
|
||||||
|
Finally we can generate pretty pictures, and run the plotting code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 plot_poison.py
|
||||||
|
```
|
||||||
|
|
||||||
|
which should give (something like) the following output
|
||||||
|
|
||||||
|
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
||||||
|
|
||||||
|
```
|
||||||
|
Attack No poison (LiRA)
|
||||||
|
AUC 0.7025, Accuracy 0.6258, TPR@0.1%FPR of 0.0544
|
||||||
|
Attack No poison (Global threshold)
|
||||||
|
AUC 0.6191, Accuracy 0.6173, TPR@0.1%FPR of 0.0012
|
||||||
|
Attack With poison (LiRA)
|
||||||
|
AUC 0.9943, Accuracy 0.9653, TPR@0.1%FPR of 0.4945
|
||||||
|
Attack With poison (Global threshold)
|
||||||
|
AUC 0.9922, Accuracy 0.9603, TPR@0.1%FPR of 0.3930
|
||||||
|
```
|
||||||
|
|
||||||
|
where the baselines are LiRA and a simple global threshold on the membership
|
||||||
|
scores, both without poisoning. With poisoning, both LiRA and the global
|
||||||
|
threshold attack are boosted significantly. Note that because we only train a
|
||||||
|
few models, we use the fixed variance variant of LiRA.
|
||||||
|
|
||||||
|
### Citation
|
||||||
|
|
||||||
|
You can cite this paper with
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{tramer2022truth,
|
||||||
|
title={Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets},
|
||||||
|
author={Tramer, Florian and Shokri, Reza and San Joaquin, Ayrton and Le, Hoang and Jagielski, Matthew and Hong, Sanghyun and Carlini, Nicholas},
|
||||||
|
journal={arXiv preprint arXiv:2204.00032},
|
||||||
|
year={2022}
|
||||||
|
}
|
||||||
|
```
|
BIN
research/mi_poison_2022/fprtpr.png
Normal file
BIN
research/mi_poison_2022/fprtpr.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 32 KiB |
13
research/mi_poison_2022/logs/.keep
Normal file
13
research/mi_poison_2022/logs/.keep
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
116
research/mi_poison_2022/plot_poison.py
Normal file
116
research/mi_poison_2022/plot_poison.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
# pyformat: disable
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.rcParams['pdf.fonttype'] = 42
|
||||||
|
matplotlib.rcParams['ps.fonttype'] = 42
|
||||||
|
|
||||||
|
# from mi_lira_2021
|
||||||
|
from plot import sweep, load_data, generate_ours, generate_global
|
||||||
|
|
||||||
|
|
||||||
|
def do_plot_all(fn, keep, scores, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
|
||||||
|
"""
|
||||||
|
Generate the ROC curves by using one model as test model and the rest to train,
|
||||||
|
with a full leave-one-out cross-validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
all_predictions = []
|
||||||
|
all_answers = []
|
||||||
|
for i in range(0, len(keep)):
|
||||||
|
mask = np.zeros(len(keep), dtype=bool)
|
||||||
|
mask[i:i+1] = True
|
||||||
|
prediction, answers = fn(keep[~mask],
|
||||||
|
scores[~mask],
|
||||||
|
keep[mask],
|
||||||
|
scores[mask])
|
||||||
|
all_predictions.extend(prediction)
|
||||||
|
all_answers.extend(answers)
|
||||||
|
|
||||||
|
fpr, tpr, auc, acc = sweep_fn(np.array(all_predictions),
|
||||||
|
np.array(all_answers, dtype=bool))
|
||||||
|
|
||||||
|
low = tpr[np.where(fpr < .001)[0][-1]]
|
||||||
|
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc, acc, low))
|
||||||
|
|
||||||
|
metric_text = ''
|
||||||
|
if metric == 'auc':
|
||||||
|
metric_text = 'auc=%.3f' % auc
|
||||||
|
elif metric == 'acc':
|
||||||
|
metric_text = 'acc=%.3f' % acc
|
||||||
|
|
||||||
|
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
|
||||||
|
return acc, auc
|
||||||
|
|
||||||
|
|
||||||
|
def fig_fpr_tpr(poison_mask, scores, keep):
|
||||||
|
|
||||||
|
plt.figure(figsize=(4, 3))
|
||||||
|
|
||||||
|
# evaluate LiRA on the points that were not targeted by poisoning
|
||||||
|
do_plot_all(functools.partial(generate_ours, fix_variance=True),
|
||||||
|
keep[:, ~poison_mask], scores[:, ~poison_mask],
|
||||||
|
"No poison (LiRA)\n",
|
||||||
|
metric='auc',
|
||||||
|
)
|
||||||
|
|
||||||
|
# evaluate the global-threshold attack on the points that were not targeted by poisoning
|
||||||
|
do_plot_all(generate_global,
|
||||||
|
keep[:, ~poison_mask], scores[:, ~poison_mask],
|
||||||
|
"No poison (Global threshold)\n",
|
||||||
|
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
|
||||||
|
)
|
||||||
|
|
||||||
|
# evaluate LiRA on the points that were targeted by poisoning
|
||||||
|
do_plot_all(functools.partial(generate_ours, fix_variance=True),
|
||||||
|
keep[:, poison_mask], scores[:, poison_mask],
|
||||||
|
"With poison (LiRA)\n",
|
||||||
|
metric='auc',
|
||||||
|
)
|
||||||
|
|
||||||
|
# evaluate the global-threshold attack on the points that were targeted by poisoning
|
||||||
|
do_plot_all(generate_global,
|
||||||
|
keep[:, poison_mask], scores[:, poison_mask],
|
||||||
|
"With poison (Global threshold)\n",
|
||||||
|
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.semilogx()
|
||||||
|
plt.semilogy()
|
||||||
|
plt.xlim(1e-3, 1)
|
||||||
|
plt.ylim(1e-3, 1)
|
||||||
|
plt.xlabel("False Positive Rate")
|
||||||
|
plt.ylabel("True Positive Rate")
|
||||||
|
plt.plot([0, 1], [0, 1], ls='--', color='gray')
|
||||||
|
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
|
||||||
|
plt.legend(fontsize=8)
|
||||||
|
plt.savefig("/tmp/fprtpr.png")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logdir = "exp/cifar10/"
|
||||||
|
scores, keep = load_data(logdir)
|
||||||
|
poison_pos = np.load(os.path.join(logdir, "poison_pos.npy"))
|
||||||
|
poison_mask = np.zeros(scores.shape[1], dtype=bool)
|
||||||
|
poison_mask[poison_pos] = True
|
||||||
|
fig_fpr_tpr(poison_mask, scores, keep)
|
30
research/mi_poison_2022/scripts/train_demo.sh
Normal file
30
research/mi_poison_2022/scripts/train_demo.sh
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
|
32
research/mi_poison_2022/scripts/train_demo_multigpu.sh
Normal file
32
research/mi_poison_2022/scripts/train_demo_multigpu.sh
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
|
||||||
|
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
|
||||||
|
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
|
||||||
|
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
|
||||||
|
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||||
|
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||||
|
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||||
|
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||||
|
wait;
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
|
||||||
|
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
|
||||||
|
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
|
||||||
|
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
|
||||||
|
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||||
|
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||||
|
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||||
|
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||||
|
wait;
|
214
research/mi_poison_2022/train_poison.py
Normal file
214
research/mi_poison_2022/train_poison.py
Normal file
|
@ -0,0 +1,214 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
# pyformat: disable
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf # For data augmentation.
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
from absl import app, flags
|
||||||
|
|
||||||
|
from objax.util import EasyDict
|
||||||
|
|
||||||
|
# from mi_lira_2021
|
||||||
|
from dataset import DataSet
|
||||||
|
from train import augment, MemModule, network
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(seed):
|
||||||
|
"""
|
||||||
|
This is the function to generate subsets of the data for training models.
|
||||||
|
|
||||||
|
First, we get the training dataset either from the numpy cache
|
||||||
|
or otherwise we load it from tensorflow datasets.
|
||||||
|
|
||||||
|
Then, we compute the subset. This works in one of two ways.
|
||||||
|
|
||||||
|
1. If we have a seed, then we just randomly choose examples based on
|
||||||
|
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
|
||||||
|
|
||||||
|
2. Otherwise, if we have an experiment ID, then we do something fancier.
|
||||||
|
If we run each experiment independently then even after a lot of trials
|
||||||
|
there will still probably be some examples that were always included
|
||||||
|
or always excluded. So instead, with experiment IDs, we guarantee that
|
||||||
|
after FLAGS.num_experiments are done, each example is seen exactly half
|
||||||
|
of the time in train, and half of the time not in train.
|
||||||
|
|
||||||
|
Finally, we add some poisons. The same poisoned samples are added for
|
||||||
|
each randomly generated training set.
|
||||||
|
We first select FLAGS.num_poison_targets victim points that will be targeted
|
||||||
|
by the poisoning attack. For each of these victim points, the attacker will
|
||||||
|
insert FLAGS.poison_reps mislabeled replicas of the point into the training
|
||||||
|
set.
|
||||||
|
|
||||||
|
For CIFAR-10, we recommend that:
|
||||||
|
|
||||||
|
`FLAGS.num_poison_targets * FLAGS.poison_reps < 5000`
|
||||||
|
|
||||||
|
Otherwise, the poisons might introduce too much label noise and the model's
|
||||||
|
accuracy (and the attack's success rate) will be degraded.
|
||||||
|
"""
|
||||||
|
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
|
||||||
|
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
|
||||||
|
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
|
||||||
|
else:
|
||||||
|
print("First time, creating dataset")
|
||||||
|
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
|
||||||
|
inputs = data['train']['image']
|
||||||
|
labels = data['train']['label']
|
||||||
|
|
||||||
|
inputs = (inputs/127.5)-1
|
||||||
|
np.save(os.path.join(FLAGS.logdir, "x_train.npy"), inputs)
|
||||||
|
np.save(os.path.join(FLAGS.logdir, "y_train.npy"), labels)
|
||||||
|
|
||||||
|
nclass = np.max(labels)+1
|
||||||
|
|
||||||
|
np.random.seed(seed)
|
||||||
|
if FLAGS.num_experiments is not None:
|
||||||
|
np.random.seed(0)
|
||||||
|
keep = np.random.uniform(0, 1, size=(FLAGS.num_experiments, len(inputs)))
|
||||||
|
order = keep.argsort(0)
|
||||||
|
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
|
||||||
|
keep = np.array(keep[FLAGS.expid], dtype=bool)
|
||||||
|
else:
|
||||||
|
keep = np.random.uniform(0, 1, size=len(inputs)) <= FLAGS.pkeep
|
||||||
|
|
||||||
|
xs = inputs[keep]
|
||||||
|
ys = labels[keep]
|
||||||
|
|
||||||
|
if FLAGS.num_poison_targets > 0:
|
||||||
|
|
||||||
|
# select some points as targets
|
||||||
|
np.random.seed(FLAGS.poison_pos_seed)
|
||||||
|
poison_pos = np.random.choice(len(inputs), size=FLAGS.num_poison_targets, replace=False)
|
||||||
|
|
||||||
|
# create mislabeled poisons for the targeted points and replicate each
|
||||||
|
# poison `FLAGS.poison_reps` times
|
||||||
|
y_noise = np.mod(labels[poison_pos] + np.random.randint(low=1, high=nclass, size=FLAGS.num_poison_targets), nclass)
|
||||||
|
ypoison = np.repeat(y_noise, FLAGS.poison_reps)
|
||||||
|
xpoison = np.repeat(inputs[poison_pos], FLAGS.poison_reps, axis=0)
|
||||||
|
xs = np.concatenate((xs, xpoison), axis=0)
|
||||||
|
ys = np.concatenate((ys, ypoison), axis=0)
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(FLAGS.logdir, "poison_pos.npy")):
|
||||||
|
np.save(os.path.join(FLAGS.logdir, "poison_pos.npy"), poison_pos)
|
||||||
|
|
||||||
|
if FLAGS.augment == 'weak':
|
||||||
|
aug = lambda x: augment(x, 4)
|
||||||
|
elif FLAGS.augment == 'mirror':
|
||||||
|
aug = lambda x: augment(x, 0)
|
||||||
|
elif FLAGS.augment == 'none':
|
||||||
|
aug = lambda x: augment(x, 0, mirror=False)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
print(xs.shape, ys.shape)
|
||||||
|
train = DataSet.from_arrays(xs, ys,
|
||||||
|
augment_fn=aug)
|
||||||
|
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
|
||||||
|
train = train.cache().shuffle(len(xs)).repeat().parse().augment().batch(FLAGS.batch)
|
||||||
|
train = train.nchw().one_hot(nclass).prefetch(FLAGS.batch)
|
||||||
|
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(FLAGS.batch)
|
||||||
|
|
||||||
|
return train, test, xs, ys, keep, nclass
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
tf.config.experimental.set_visible_devices([], "GPU")
|
||||||
|
|
||||||
|
seed = FLAGS.seed
|
||||||
|
if seed is None:
|
||||||
|
import time
|
||||||
|
seed = np.random.randint(0, 1000000000)
|
||||||
|
seed ^= int(time.time())
|
||||||
|
|
||||||
|
args = EasyDict(arch=FLAGS.arch,
|
||||||
|
lr=FLAGS.lr,
|
||||||
|
batch=FLAGS.batch,
|
||||||
|
weight_decay=FLAGS.weight_decay,
|
||||||
|
augment=FLAGS.augment,
|
||||||
|
seed=seed)
|
||||||
|
|
||||||
|
if FLAGS.expid is not None:
|
||||||
|
logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
|
||||||
|
else:
|
||||||
|
logdir = "experiment-"+str(seed)
|
||||||
|
logdir = os.path.join(FLAGS.logdir, logdir)
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % FLAGS.epochs)):
|
||||||
|
print(f"run {FLAGS.expid} already completed.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if os.path.exists(logdir):
|
||||||
|
print(f"deleting run {FLAGS.expid} that did not complete.")
|
||||||
|
shutil.rmtree(logdir)
|
||||||
|
|
||||||
|
print(f"starting run {FLAGS.expid}.")
|
||||||
|
if not os.path.exists(logdir):
|
||||||
|
os.makedirs(logdir)
|
||||||
|
|
||||||
|
train, test, xs, ys, keep, nclass = get_data(seed)
|
||||||
|
|
||||||
|
# Define the network and train_it
|
||||||
|
tm = MemModule(network(FLAGS.arch), nclass=nclass,
|
||||||
|
mnist=FLAGS.dataset == 'mnist',
|
||||||
|
epochs=FLAGS.epochs,
|
||||||
|
expid=FLAGS.expid,
|
||||||
|
num_experiments=FLAGS.num_experiments,
|
||||||
|
pkeep=FLAGS.pkeep,
|
||||||
|
save_steps=FLAGS.save_steps,
|
||||||
|
**args
|
||||||
|
)
|
||||||
|
|
||||||
|
r = {}
|
||||||
|
r.update(tm.params)
|
||||||
|
|
||||||
|
open(os.path.join(logdir, 'hparams.json'), "w").write(json.dumps(tm.params))
|
||||||
|
np.save(os.path.join(logdir,'keep.npy'), keep)
|
||||||
|
|
||||||
|
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
|
||||||
|
save_steps=FLAGS.save_steps)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
|
||||||
|
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
|
||||||
|
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||||
|
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
|
||||||
|
flags.DEFINE_integer('batch', 256, 'Batch size')
|
||||||
|
flags.DEFINE_integer('epochs', 100, 'Training duration in number of epochs.')
|
||||||
|
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
|
||||||
|
flags.DEFINE_integer('seed', None, 'Training seed.')
|
||||||
|
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
|
||||||
|
flags.DEFINE_integer('expid', None, 'Experiment ID')
|
||||||
|
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
|
||||||
|
flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
|
||||||
|
flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
|
||||||
|
flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
|
||||||
|
|
||||||
|
flags.DEFINE_integer('num_poison_targets', 250, 'Number of points to target '
|
||||||
|
'with the poisoning attack.')
|
||||||
|
flags.DEFINE_integer('poison_reps', 8, 'Number of times to repeat each poison.')
|
||||||
|
flags.DEFINE_integer('poison_pos_seed', 0, '')
|
||||||
|
app.run(main)
|
Loading…
Reference in a new issue