Merge pull request #183 from carlini/better-mi
Add research code to reproduce Membership Inference Attacks From First Principles
This commit is contained in:
commit
3d499e69ba
10 changed files with 1027 additions and 0 deletions
129
research/mi_lira_2021/README.md
Normal file
129
research/mi_lira_2021/README.md
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
## Membership Inference Attacks From First Principles
|
||||||
|
|
||||||
|
This directory contains code to reproduce our paper:
|
||||||
|
|
||||||
|
**"Membership Inference Attacks From First Principles"**
|
||||||
|
https://arxiv.org/abs/2112.03570
|
||||||
|
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
|
||||||
|
|
||||||
|
|
||||||
|
### INSTALLING
|
||||||
|
|
||||||
|
You will need to install fairly standard dependencies
|
||||||
|
|
||||||
|
`pip install scipy, sklearn, numpy, matplotlib`
|
||||||
|
|
||||||
|
and also some machine learning framework to train models. We train our models
|
||||||
|
with JAX + ObJAX so you will need to follow build instructions for that
|
||||||
|
https://github.com/google/objax
|
||||||
|
https://objax.readthedocs.io/en/latest/installation_setup.html
|
||||||
|
|
||||||
|
|
||||||
|
### RUNNING THE CODE
|
||||||
|
|
||||||
|
#### 1. Train the models
|
||||||
|
|
||||||
|
The first step in our attack is to train shadow models. As a baseline
|
||||||
|
that should give most of the gains in our attack, you should start by
|
||||||
|
training 16 shadow models with the command
|
||||||
|
|
||||||
|
> bash scripts/train_demo.sh
|
||||||
|
|
||||||
|
or if you have multiple GPUs on your machine and want to train these models
|
||||||
|
in parallel, then modify and run
|
||||||
|
|
||||||
|
> bash scripts/train_demo_multigpu.sh
|
||||||
|
|
||||||
|
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
|
||||||
|
will output a bunch of files under the directory exp/cifar10 with structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
exp/cifar10/
|
||||||
|
- experiment_N_of_16
|
||||||
|
-- hparams.json
|
||||||
|
-- keep.npy
|
||||||
|
-- ckpt/
|
||||||
|
--- 0000000100.npz
|
||||||
|
-- tb/
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Perform inference
|
||||||
|
|
||||||
|
Once the models are trained, now it's necessary to perform inference and save
|
||||||
|
the output features for each training example for each model in the dataset.
|
||||||
|
|
||||||
|
> python3 inference.py --logdir=exp/cifar10/
|
||||||
|
|
||||||
|
This will add to the experiment directory a new set of files
|
||||||
|
|
||||||
|
```
|
||||||
|
exp/cifar10/
|
||||||
|
- experiment_N_of_16
|
||||||
|
-- logits/
|
||||||
|
--- 0000000100.npy
|
||||||
|
```
|
||||||
|
|
||||||
|
where this new file has shape (50000, 10) and stores the model's
|
||||||
|
output features for each example.
|
||||||
|
|
||||||
|
|
||||||
|
#### 3. Compute membership inference scores
|
||||||
|
|
||||||
|
Finally we take the output features and generate our logit-scaled membership inference
|
||||||
|
scores for each example for each model.
|
||||||
|
|
||||||
|
> python3 score.py exp/cifar10/
|
||||||
|
|
||||||
|
And this in turn generates a new directory
|
||||||
|
|
||||||
|
```
|
||||||
|
exp/cifar10/
|
||||||
|
- experiment_N_of_16
|
||||||
|
-- scores/
|
||||||
|
--- 0000000100.npy
|
||||||
|
```
|
||||||
|
|
||||||
|
with shape (50000,) storing just our scores.
|
||||||
|
|
||||||
|
|
||||||
|
### PLOTTING THE RESULTS
|
||||||
|
|
||||||
|
Finally we can generate pretty pictures, and run the plotting code
|
||||||
|
|
||||||
|
> python3 plot.py
|
||||||
|
|
||||||
|
which should give (something like) the following output
|
||||||
|
|
||||||
|
|
||||||
|
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
||||||
|
|
||||||
|
```
|
||||||
|
Attack Ours (online)
|
||||||
|
AUC 0.6676, Accuracy 0.6077, TPR@0.1%FPR of 0.0169
|
||||||
|
Attack Ours (online, fixed variance)
|
||||||
|
AUC 0.6856, Accuracy 0.6137, TPR@0.1%FPR of 0.0593
|
||||||
|
Attack Ours (offline)
|
||||||
|
AUC 0.5488, Accuracy 0.5500, TPR@0.1%FPR of 0.0130
|
||||||
|
Attack Ours (offline, fixed variance)
|
||||||
|
AUC 0.5549, Accuracy 0.5537, TPR@0.1%FPR of 0.0299
|
||||||
|
Attack Global threshold
|
||||||
|
AUC 0.5921, Accuracy 0.6044, TPR@0.1%FPR of 0.0009
|
||||||
|
```
|
||||||
|
|
||||||
|
where the global threshold attack is the baseline, and our online,
|
||||||
|
online-with-fixed-variance, offline, and offline-with-fixed-variance
|
||||||
|
attack variants are the four other curves. Note that because we only
|
||||||
|
train a few models, the fixed variance variants perform best.
|
||||||
|
|
||||||
|
### Citation
|
||||||
|
|
||||||
|
You can cite this paper with
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{carlini2021membership,
|
||||||
|
title={Membership Inference Attacks From First Principles},
|
||||||
|
author={Carlini, Nicholas and Chien, Steve and Nasr, Milad and Song, Shuang and Terzis, Andreas and Tramer, Florian},
|
||||||
|
journal={arXiv preprint arXiv:2112.03570},
|
||||||
|
year={2021}
|
||||||
|
}
|
||||||
|
```
|
95
research/mi_lira_2021/dataset.py
Normal file
95
research/mi_lira_2021/dataset.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
# Copyright 2020 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Callable, Optional, Tuple, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
|
||||||
|
features = tf.io.parse_single_example(serialized_example,
|
||||||
|
features={'image': tf.io.FixedLenFeature([], tf.string),
|
||||||
|
'label': tf.io.FixedLenFeature([], tf.int64)})
|
||||||
|
image = tf.image.decode_image(features['image']).set_shape(image_shape)
|
||||||
|
image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
|
||||||
|
return dict(image=image, label=features['label'])
|
||||||
|
|
||||||
|
|
||||||
|
class DataSet:
|
||||||
|
"""Wrapper for tf.data.Dataset to permit extensions."""
|
||||||
|
|
||||||
|
def __init__(self, data: tf.data.Dataset,
|
||||||
|
image_shape: Tuple[int, int, int],
|
||||||
|
augment_fn: Optional[Callable] = None,
|
||||||
|
parse_fn: Optional[Callable] = record_parse):
|
||||||
|
self.data = data
|
||||||
|
self.parse_fn = parse_fn
|
||||||
|
self.augment_fn = augment_fn
|
||||||
|
self.image_shape = image_shape
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
|
||||||
|
return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
|
||||||
|
augment_fn=augment_fn, parse_fn=None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_files(cls, filenames: List[str],
|
||||||
|
image_shape: Tuple[int, int, int],
|
||||||
|
augment_fn: Optional[Callable],
|
||||||
|
parse_fn: Optional[Callable] = record_parse):
|
||||||
|
filenames_in = filenames
|
||||||
|
filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
|
||||||
|
if not filenames:
|
||||||
|
raise ValueError('Empty dataset, files not found:', filenames_in)
|
||||||
|
return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
|
||||||
|
augment_fn: Optional[Callable] = None):
|
||||||
|
return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
|
||||||
|
image_shape, augment_fn=augment_fn, parse_fn=None)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.data)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item in self.__dict__:
|
||||||
|
return self.__dict__[item]
|
||||||
|
|
||||||
|
def call_and_update(*args, **kwargs):
|
||||||
|
v = getattr(self.__dict__['data'], item)(*args, **kwargs)
|
||||||
|
if isinstance(v, tf.data.Dataset):
|
||||||
|
return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
|
||||||
|
return v
|
||||||
|
|
||||||
|
return call_and_update
|
||||||
|
|
||||||
|
def augment(self, para_augment: int = 4):
|
||||||
|
if self.augment_fn:
|
||||||
|
return self.map(self.augment_fn, para_augment)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def nchw(self):
|
||||||
|
return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))
|
||||||
|
|
||||||
|
def one_hot(self, nclass: int):
|
||||||
|
return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))
|
||||||
|
|
||||||
|
def parse(self, para_parse: int = 2):
|
||||||
|
if not self.parse_fn:
|
||||||
|
return self
|
||||||
|
if self.image_shape:
|
||||||
|
return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
|
||||||
|
return self.map(self.parse_fn, para_parse)
|
BIN
research/mi_lira_2021/fprtpr.png
Normal file
BIN
research/mi_lira_2021/fprtpr.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 37 KiB |
150
research/mi_lira_2021/inference.py
Normal file
150
research/mi_lira_2021/inference.py
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
from typing import Callable
|
||||||
|
import json
|
||||||
|
|
||||||
|
import re
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jn
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf # For data augmentation.
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
from absl import app, flags
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import pickle
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import objax
|
||||||
|
from objax.jaxboard import SummaryWriter, Summary
|
||||||
|
from objax.util import EasyDict
|
||||||
|
from objax.zoo import convnet, wide_resnet
|
||||||
|
|
||||||
|
from dataset import DataSet
|
||||||
|
|
||||||
|
from train import MemModule, network
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
"""
|
||||||
|
Perform inference of the saved model in order to generate the
|
||||||
|
output logits, using a particular set of augmentations.
|
||||||
|
"""
|
||||||
|
del argv
|
||||||
|
tf.config.experimental.set_visible_devices([], "GPU")
|
||||||
|
|
||||||
|
def load(arch):
|
||||||
|
return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10,
|
||||||
|
mnist=FLAGS.dataset == 'mnist',
|
||||||
|
arch=arch,
|
||||||
|
lr=.1,
|
||||||
|
batch=0,
|
||||||
|
epochs=0,
|
||||||
|
weight_decay=0)
|
||||||
|
|
||||||
|
def cache_load(arch):
|
||||||
|
thing = []
|
||||||
|
def fn():
|
||||||
|
if len(thing) == 0:
|
||||||
|
thing.append(load(arch))
|
||||||
|
return thing[0]
|
||||||
|
return fn
|
||||||
|
|
||||||
|
xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
|
||||||
|
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
|
||||||
|
|
||||||
|
|
||||||
|
def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]:
|
||||||
|
aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy()
|
||||||
|
for dx in range(0, 2*shift+1, stride):
|
||||||
|
for dy in range(0, 2*shift+1, stride):
|
||||||
|
this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))
|
||||||
|
|
||||||
|
logits = model.model(this_x, training=True)
|
||||||
|
outs.append(logits)
|
||||||
|
|
||||||
|
print(np.array(outs).shape)
|
||||||
|
return np.array(outs).transpose((1, 0, 2))
|
||||||
|
|
||||||
|
N = 5000
|
||||||
|
|
||||||
|
def features(model, xbatch, ybatch):
|
||||||
|
return get_loss(model, xbatch, ybatch,
|
||||||
|
shift=0, reflect=True, stride=1)
|
||||||
|
|
||||||
|
for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
|
||||||
|
if re.search(FLAGS.regex, path) is None:
|
||||||
|
print("Skipping from regex")
|
||||||
|
continue
|
||||||
|
|
||||||
|
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
|
||||||
|
arch = hparams['arch']
|
||||||
|
model = cache_load(arch)()
|
||||||
|
|
||||||
|
logdir = os.path.join(FLAGS.logdir, path)
|
||||||
|
|
||||||
|
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
|
||||||
|
max_epoch, last_ckpt = checkpoint.restore(model.vars())
|
||||||
|
if max_epoch == 0: continue
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")):
|
||||||
|
os.mkdir(os.path.join(FLAGS.logdir, path, "logits"))
|
||||||
|
if FLAGS.from_epoch is not None:
|
||||||
|
first = FLAGS.from_epoch
|
||||||
|
else:
|
||||||
|
first = max_epoch-1
|
||||||
|
|
||||||
|
for epoch in range(first,max_epoch+1):
|
||||||
|
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
|
||||||
|
# no checkpoint saved here
|
||||||
|
continue
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
|
||||||
|
print("Skipping already generated file", epoch)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
|
||||||
|
except:
|
||||||
|
print("Fail to load", epoch)
|
||||||
|
continue
|
||||||
|
|
||||||
|
stats = []
|
||||||
|
|
||||||
|
for i in range(0,len(xs_all),N):
|
||||||
|
stats.extend(features(model, xs_all[i:i+N],
|
||||||
|
ys_all[i:i+N]))
|
||||||
|
# This will be shape N, augs, nclass
|
||||||
|
|
||||||
|
np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch),
|
||||||
|
np.array(stats)[:,None,:,:])
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||||
|
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
|
||||||
|
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
|
||||||
|
flags.DEFINE_bool('random_labels', False, 'use random labels.')
|
||||||
|
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
|
||||||
|
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
|
||||||
|
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
|
||||||
|
flags.DEFINE_integer('modulus', 8, 'modulus.')
|
||||||
|
app.run(main)
|
0
research/mi_lira_2021/logs/.keep
Normal file
0
research/mi_lira_2021/logs/.keep
Normal file
224
research/mi_lira_2021/plot.py
Normal file
224
research/mi_lira_2021/plot.py
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import scipy.stats
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.metrics import auc, roc_curve
|
||||||
|
import functools
|
||||||
|
|
||||||
|
# Look at me being proactive!
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.rcParams['pdf.fonttype'] = 42
|
||||||
|
matplotlib.rcParams['ps.fonttype'] = 42
|
||||||
|
|
||||||
|
|
||||||
|
def sweep(score, x):
|
||||||
|
"""
|
||||||
|
Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
|
||||||
|
"""
|
||||||
|
fpr, tpr, _ = roc_curve(x, -score)
|
||||||
|
acc = np.max(1-(fpr+(1-tpr))/2)
|
||||||
|
return fpr, tpr, auc(fpr, tpr), acc
|
||||||
|
|
||||||
|
def load_data(p):
|
||||||
|
"""
|
||||||
|
Load our saved scores and then put them into a big matrix.
|
||||||
|
"""
|
||||||
|
global scores, keep
|
||||||
|
scores = []
|
||||||
|
keep = []
|
||||||
|
|
||||||
|
for root,ds,_ in os.walk(p):
|
||||||
|
for f in ds:
|
||||||
|
if not f.startswith("experiment"): continue
|
||||||
|
if not os.path.exists(os.path.join(root,f,"scores")): continue
|
||||||
|
last_epoch = sorted(os.listdir(os.path.join(root,f,"scores")))
|
||||||
|
if len(last_epoch) == 0: continue
|
||||||
|
scores.append(np.load(os.path.join(root,f,"scores",last_epoch[-1])))
|
||||||
|
keep.append(np.load(os.path.join(root,f,"keep.npy")))
|
||||||
|
|
||||||
|
scores = np.array(scores)
|
||||||
|
keep = np.array(keep)[:,:scores.shape[1]]
|
||||||
|
|
||||||
|
return scores, keep
|
||||||
|
|
||||||
|
def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
|
||||||
|
fix_variance=False):
|
||||||
|
"""
|
||||||
|
Fit a two predictive models using keep and scores in order to predict
|
||||||
|
if the examples in check_scores were training data or not, using the
|
||||||
|
ground truth answer from check_keep.
|
||||||
|
"""
|
||||||
|
dat_in = []
|
||||||
|
dat_out = []
|
||||||
|
|
||||||
|
for j in range(scores.shape[1]):
|
||||||
|
dat_in.append(scores[keep[:,j],j,:])
|
||||||
|
dat_out.append(scores[~keep[:,j],j,:])
|
||||||
|
|
||||||
|
in_size = min(min(map(len,dat_in)), in_size)
|
||||||
|
out_size = min(min(map(len,dat_out)), out_size)
|
||||||
|
|
||||||
|
dat_in = np.array([x[:in_size] for x in dat_in])
|
||||||
|
dat_out = np.array([x[:out_size] for x in dat_out])
|
||||||
|
|
||||||
|
mean_in = np.median(dat_in, 1)
|
||||||
|
mean_out = np.median(dat_out, 1)
|
||||||
|
|
||||||
|
if fix_variance:
|
||||||
|
std_in = np.std(dat_in)
|
||||||
|
std_out = np.std(dat_in)
|
||||||
|
else:
|
||||||
|
std_in = np.std(dat_in, 1)
|
||||||
|
std_out = np.std(dat_out, 1)
|
||||||
|
|
||||||
|
prediction = []
|
||||||
|
answers = []
|
||||||
|
for ans, sc in zip(check_keep, check_scores):
|
||||||
|
pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in+1e-30)
|
||||||
|
pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
|
||||||
|
score = pr_in-pr_out
|
||||||
|
|
||||||
|
prediction.extend(score.mean(1))
|
||||||
|
answers.extend(ans)
|
||||||
|
|
||||||
|
return prediction, answers
|
||||||
|
|
||||||
|
def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
|
||||||
|
fix_variance=False):
|
||||||
|
"""
|
||||||
|
Fit a single predictive model using keep and scores in order to predict
|
||||||
|
if the examples in check_scores were training data or not, using the
|
||||||
|
ground truth answer from check_keep.
|
||||||
|
"""
|
||||||
|
dat_in = []
|
||||||
|
dat_out = []
|
||||||
|
|
||||||
|
for j in range(scores.shape[1]):
|
||||||
|
dat_in.append(scores[keep[:, j], j, :])
|
||||||
|
dat_out.append(scores[~keep[:, j], j, :])
|
||||||
|
|
||||||
|
out_size = min(min(map(len,dat_out)), out_size)
|
||||||
|
|
||||||
|
dat_out = np.array([x[:out_size] for x in dat_out])
|
||||||
|
|
||||||
|
mean_out = np.median(dat_out, 1)
|
||||||
|
|
||||||
|
if fix_variance:
|
||||||
|
std_out = np.std(dat_out)
|
||||||
|
else:
|
||||||
|
std_out = np.std(dat_out, 1)
|
||||||
|
|
||||||
|
prediction = []
|
||||||
|
answers = []
|
||||||
|
for ans, sc in zip(check_keep, check_scores):
|
||||||
|
score = scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
|
||||||
|
|
||||||
|
prediction.extend(score.mean(1))
|
||||||
|
answers.extend(ans)
|
||||||
|
return prediction, answers
|
||||||
|
|
||||||
|
|
||||||
|
def generate_global(keep, scores, check_keep, check_scores):
|
||||||
|
"""
|
||||||
|
Use a simple global threshold sweep to predict if the examples in
|
||||||
|
check_scores were training data or not, using the ground truth answer from
|
||||||
|
check_keep.
|
||||||
|
"""
|
||||||
|
prediction = []
|
||||||
|
answers = []
|
||||||
|
for ans, sc in zip(check_keep, check_scores):
|
||||||
|
prediction.extend(-sc.mean(1))
|
||||||
|
answers.extend(ans)
|
||||||
|
|
||||||
|
return prediction, answers
|
||||||
|
|
||||||
|
def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
|
||||||
|
"""
|
||||||
|
Generate the ROC curves by using ntest models as test models and the rest to train.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prediction, answers = fn(keep[:-ntest],
|
||||||
|
scores[:-ntest],
|
||||||
|
keep[-ntest:],
|
||||||
|
scores[-ntest:])
|
||||||
|
|
||||||
|
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
|
||||||
|
|
||||||
|
low = tpr[np.where(fpr<.001)[0][-1]]
|
||||||
|
|
||||||
|
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
|
||||||
|
|
||||||
|
metric_text = ''
|
||||||
|
if metric == 'auc':
|
||||||
|
metric_text = 'auc=%.3f'%auc
|
||||||
|
elif metric == 'acc':
|
||||||
|
metric_text = 'acc=%.3f'%acc
|
||||||
|
|
||||||
|
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
|
||||||
|
return (acc,auc)
|
||||||
|
|
||||||
|
|
||||||
|
def fig_fpr_tpr():
|
||||||
|
|
||||||
|
plt.figure(figsize=(4,3))
|
||||||
|
|
||||||
|
do_plot(generate_ours,
|
||||||
|
keep, scores, 1,
|
||||||
|
"Ours (online)\n",
|
||||||
|
metric='auc'
|
||||||
|
)
|
||||||
|
|
||||||
|
do_plot(functools.partial(generate_ours, fix_variance=True),
|
||||||
|
keep, scores, 1,
|
||||||
|
"Ours (online, fixed variance)\n",
|
||||||
|
metric='auc'
|
||||||
|
)
|
||||||
|
|
||||||
|
do_plot(functools.partial(generate_ours_offline),
|
||||||
|
keep, scores, 1,
|
||||||
|
"Ours (offline)\n",
|
||||||
|
metric='auc'
|
||||||
|
)
|
||||||
|
|
||||||
|
do_plot(functools.partial(generate_ours_offline, fix_variance=True),
|
||||||
|
keep, scores, 1,
|
||||||
|
"Ours (offline, fixed variance)\n",
|
||||||
|
metric='auc'
|
||||||
|
)
|
||||||
|
|
||||||
|
do_plot(generate_global,
|
||||||
|
keep, scores, 1,
|
||||||
|
"Global threshold\n",
|
||||||
|
metric='auc'
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.semilogx()
|
||||||
|
plt.semilogy()
|
||||||
|
plt.xlim(1e-5,1)
|
||||||
|
plt.ylim(1e-5,1)
|
||||||
|
plt.xlabel("False Positive Rate")
|
||||||
|
plt.ylabel("True Positive Rate")
|
||||||
|
plt.plot([0, 1], [0, 1], ls='--', color='gray')
|
||||||
|
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
|
||||||
|
plt.legend(fontsize=8)
|
||||||
|
plt.savefig("/tmp/fprtpr.png")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
load_data("exp/cifar10/")
|
||||||
|
fig_fpr_tpr()
|
66
research/mi_lira_2021/score.py
Normal file
66
research/mi_lira_2021/score.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
|
||||||
|
def load_one(base):
|
||||||
|
"""
|
||||||
|
This loads a logits and converts it to a scored prediction.
|
||||||
|
"""
|
||||||
|
root = os.path.join(logdir,base,'logits')
|
||||||
|
if not os.path.exists(root): return None
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(logdir,base,'scores')):
|
||||||
|
os.mkdir(os.path.join(logdir,base,'scores'))
|
||||||
|
|
||||||
|
for f in os.listdir(root):
|
||||||
|
try:
|
||||||
|
opredictions = np.load(os.path.join(root,f))
|
||||||
|
except:
|
||||||
|
print("Fail")
|
||||||
|
continue
|
||||||
|
|
||||||
|
## Be exceptionally careful.
|
||||||
|
## Numerically stable everything, as described in the paper.
|
||||||
|
predictions = opredictions - np.max(opredictions, axis=3, keepdims=True)
|
||||||
|
predictions = np.array(np.exp(predictions), dtype=np.float64)
|
||||||
|
predictions = predictions/np.sum(predictions,axis=3,keepdims=True)
|
||||||
|
|
||||||
|
COUNT = predictions.shape[0]
|
||||||
|
# x num_examples x num_augmentations x logits
|
||||||
|
y_true = predictions[np.arange(COUNT),:,:,labels[:COUNT]]
|
||||||
|
print(y_true.shape)
|
||||||
|
|
||||||
|
print('mean acc',np.mean(predictions[:,0,0,:].argmax(1)==labels[:COUNT]))
|
||||||
|
|
||||||
|
predictions[np.arange(COUNT),:,:,labels[:COUNT]] = 0
|
||||||
|
y_wrong = np.sum(predictions, axis=3)
|
||||||
|
|
||||||
|
logit = (np.log(y_true.mean((1))+1e-45) - np.log(y_wrong.mean((1))+1e-45))
|
||||||
|
|
||||||
|
np.save(os.path.join(logdir, base, 'scores', f), logit)
|
||||||
|
|
||||||
|
|
||||||
|
def load_stats():
|
||||||
|
with mp.Pool(8) as p:
|
||||||
|
p.map(load_one, [x for x in os.listdir(logdir) if 'exp' in x])
|
||||||
|
|
||||||
|
|
||||||
|
logdir = sys.argv[1]
|
||||||
|
labels = np.load(os.path.join(logdir,"y_train.npy"))
|
||||||
|
load_stats()
|
16
research/mi_lira_2021/scripts/train_demo.sh
Normal file
16
research/mi_lira_2021/scripts/train_demo.sh
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
|
18
research/mi_lira_2021/scripts/train_demo_multigpu.sh
Normal file
18
research/mi_lira_2021/scripts/train_demo_multigpu.sh
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
|
||||||
|
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
|
||||||
|
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
|
||||||
|
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
|
||||||
|
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||||
|
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||||
|
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||||
|
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||||
|
wait;
|
||||||
|
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
|
||||||
|
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
|
||||||
|
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
|
||||||
|
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
|
||||||
|
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||||
|
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||||
|
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||||
|
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||||
|
wait;
|
329
research/mi_lira_2021/train.py
Normal file
329
research/mi_lira_2021/train.py
Normal file
|
@ -0,0 +1,329 @@
|
||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from typing import Callable
|
||||||
|
import json
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jn
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf # For data augmentation.
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
from absl import app, flags
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
import objax
|
||||||
|
from objax.jaxboard import SummaryWriter, Summary
|
||||||
|
from objax.util import EasyDict
|
||||||
|
from objax.zoo import convnet, wide_resnet, dnnet
|
||||||
|
|
||||||
|
from dataset import DataSet
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
def augment(x, shift: int, mirror=True):
|
||||||
|
"""
|
||||||
|
Augmentation function used in training the model.
|
||||||
|
"""
|
||||||
|
y = x['image']
|
||||||
|
if mirror:
|
||||||
|
y = tf.image.random_flip_left_right(y)
|
||||||
|
y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT')
|
||||||
|
y = tf.image.random_crop(y, tf.shape(x['image']))
|
||||||
|
return dict(image=y, label=x['label'])
|
||||||
|
|
||||||
|
|
||||||
|
class TrainLoop(objax.Module):
|
||||||
|
"""
|
||||||
|
Training loop for general machine learning models.
|
||||||
|
Based on the training loop from the objax CIFAR10 example code.
|
||||||
|
"""
|
||||||
|
predict: Callable
|
||||||
|
train_op: Callable
|
||||||
|
|
||||||
|
def __init__(self, nclass: int, **kwargs):
|
||||||
|
self.nclass = nclass
|
||||||
|
self.params = EasyDict(kwargs)
|
||||||
|
|
||||||
|
def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
|
||||||
|
kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
|
||||||
|
for k, v in kv.items():
|
||||||
|
if jn.isnan(v):
|
||||||
|
raise ValueError('NaN, try reducing learning rate', k)
|
||||||
|
if summary is not None:
|
||||||
|
summary.scalar(k, float(v))
|
||||||
|
|
||||||
|
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
|
||||||
|
"""
|
||||||
|
Completely standard training. Nothing interesting to see here.
|
||||||
|
"""
|
||||||
|
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
|
||||||
|
start_epoch, last_ckpt = checkpoint.restore(self.vars())
|
||||||
|
train_iter = iter(train)
|
||||||
|
progress = np.zeros(jax.local_device_count(), 'f') # for multi-GPU
|
||||||
|
|
||||||
|
best_acc = 0
|
||||||
|
best_acc_epoch = -1
|
||||||
|
|
||||||
|
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
|
||||||
|
for epoch in range(start_epoch, num_train_epochs):
|
||||||
|
# Train
|
||||||
|
summary = Summary()
|
||||||
|
loop = range(0, train_size, self.params.batch)
|
||||||
|
for step in loop:
|
||||||
|
progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
|
||||||
|
self.train_step(summary, next(train_iter), progress)
|
||||||
|
|
||||||
|
# Eval
|
||||||
|
accuracy, total = 0, 0
|
||||||
|
if epoch%FLAGS.eval_steps == 0 and test is not None:
|
||||||
|
for data in test:
|
||||||
|
total += data['image'].shape[0]
|
||||||
|
preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
|
||||||
|
accuracy += (preds == data['label'].numpy()).sum()
|
||||||
|
accuracy /= total
|
||||||
|
summary.scalar('eval/accuracy', 100 * accuracy)
|
||||||
|
tensorboard.write(summary, step=(epoch + 1) * train_size)
|
||||||
|
print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
|
||||||
|
summary['eval/accuracy']()))
|
||||||
|
|
||||||
|
if summary['eval/accuracy']() > best_acc:
|
||||||
|
best_acc = summary['eval/accuracy']()
|
||||||
|
best_acc_epoch = epoch
|
||||||
|
elif patience is not None and epoch > best_acc_epoch + patience:
|
||||||
|
print("early stopping!")
|
||||||
|
checkpoint.save(self.vars(), epoch + 1)
|
||||||
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))
|
||||||
|
|
||||||
|
if epoch%save_steps == save_steps-1:
|
||||||
|
checkpoint.save(self.vars(), epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
|
# We inherit from the training loop and define predict and train_op.
|
||||||
|
class MemModule(TrainLoop):
|
||||||
|
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Completely standard training. Nothing interesting to see here.
|
||||||
|
"""
|
||||||
|
super().__init__(nclass, **kwargs)
|
||||||
|
self.model = model(1 if mnist else 3, nclass)
|
||||||
|
self.opt = objax.optimizer.Momentum(self.model.vars())
|
||||||
|
self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)
|
||||||
|
|
||||||
|
@objax.Function.with_vars(self.model.vars())
|
||||||
|
def loss(x, label):
|
||||||
|
logit = self.model(x, training=True)
|
||||||
|
loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
|
||||||
|
loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
|
||||||
|
return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}
|
||||||
|
|
||||||
|
gv = objax.GradValues(loss, self.model.vars())
|
||||||
|
self.gv = gv
|
||||||
|
|
||||||
|
@objax.Function.with_vars(self.vars())
|
||||||
|
def train_op(progress, x, y):
|
||||||
|
g, v = gv(x, y)
|
||||||
|
lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
|
||||||
|
lr = lr * jn.clip(progress*100,0,1)
|
||||||
|
self.opt(lr, g)
|
||||||
|
self.model_ema.update_ema()
|
||||||
|
return {'monitors/lr': lr, **v[1]}
|
||||||
|
|
||||||
|
self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
|
||||||
|
|
||||||
|
self.train_op = objax.Jit(train_op)
|
||||||
|
|
||||||
|
|
||||||
|
def network(arch: str):
|
||||||
|
if arch == 'cnn32-3-max':
|
||||||
|
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
|
||||||
|
pooling=objax.functional.max_pool_2d)
|
||||||
|
elif arch == 'cnn32-3-mean':
|
||||||
|
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
|
||||||
|
pooling=objax.functional.average_pool_2d)
|
||||||
|
elif arch == 'cnn64-3-max':
|
||||||
|
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
|
||||||
|
pooling=objax.functional.max_pool_2d)
|
||||||
|
elif arch == 'cnn64-3-mean':
|
||||||
|
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
|
||||||
|
pooling=objax.functional.average_pool_2d)
|
||||||
|
elif arch == 'wrn28-1':
|
||||||
|
return functools.partial(wide_resnet.WideResNet, depth=28, width=1)
|
||||||
|
elif arch == 'wrn28-2':
|
||||||
|
return functools.partial(wide_resnet.WideResNet, depth=28, width=2)
|
||||||
|
elif arch == 'wrn28-10':
|
||||||
|
return functools.partial(wide_resnet.WideResNet, depth=28, width=10)
|
||||||
|
raise ValueError('Architecture not recognized', arch)
|
||||||
|
|
||||||
|
def get_data(seed):
|
||||||
|
"""
|
||||||
|
This is the function to generate subsets of the data for training models.
|
||||||
|
|
||||||
|
First, we get the training dataset either from the numpy cache
|
||||||
|
or otherwise we load it from tensorflow datasets.
|
||||||
|
|
||||||
|
Then, we compute the subset. This works in one of two ways.
|
||||||
|
|
||||||
|
1. If we have a seed, then we just randomly choose examples based on
|
||||||
|
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
|
||||||
|
|
||||||
|
2. Otherwise, if we have an experiment ID, then we do something fancier.
|
||||||
|
If we run each experiment independently then even after a lot of trials
|
||||||
|
there will still probably be some examples that were always included
|
||||||
|
or always excluded. So instead, with experiment IDs, we guarantee that
|
||||||
|
after FLAGS.num_experiments are done, each example is seen exactly half
|
||||||
|
of the time in train, and half of the time not in train.
|
||||||
|
|
||||||
|
"""
|
||||||
|
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
|
||||||
|
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
|
||||||
|
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
|
||||||
|
else:
|
||||||
|
print("First time, creating dataset")
|
||||||
|
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
|
||||||
|
inputs = data['train']['image']
|
||||||
|
labels = data['train']['label']
|
||||||
|
|
||||||
|
inputs = (inputs/127.5)-1
|
||||||
|
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
|
||||||
|
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
|
||||||
|
|
||||||
|
nclass = np.max(labels)+1
|
||||||
|
|
||||||
|
np.random.seed(seed)
|
||||||
|
if FLAGS.num_experiments is not None:
|
||||||
|
np.random.seed(0)
|
||||||
|
keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size))
|
||||||
|
order = keep.argsort(0)
|
||||||
|
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
|
||||||
|
keep = np.array(keep[FLAGS.expid], dtype=bool)
|
||||||
|
else:
|
||||||
|
keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= FLAGS.pkeep
|
||||||
|
|
||||||
|
if FLAGS.only_subset is not None:
|
||||||
|
keep[FLAGS.only_subset:] = 0
|
||||||
|
|
||||||
|
xs = inputs[keep]
|
||||||
|
ys = labels[keep]
|
||||||
|
|
||||||
|
if FLAGS.augment == 'weak':
|
||||||
|
aug = lambda x: augment(x, 4)
|
||||||
|
elif FLAGS.augment == 'mirror':
|
||||||
|
aug = lambda x: augment(x, 0)
|
||||||
|
elif FLAGS.augment == 'none':
|
||||||
|
aug = lambda x: augment(x, 0, mirror=False)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
train = DataSet.from_arrays(xs, ys,
|
||||||
|
augment_fn=aug)
|
||||||
|
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
|
||||||
|
train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch)
|
||||||
|
train = train.nchw().one_hot(nclass).prefetch(16)
|
||||||
|
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16)
|
||||||
|
|
||||||
|
return train, test, xs, ys, keep, nclass
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
tf.config.experimental.set_visible_devices([], "GPU")
|
||||||
|
|
||||||
|
seed = FLAGS.seed
|
||||||
|
if seed is None:
|
||||||
|
import time
|
||||||
|
seed = np.random.randint(0, 1000000000)
|
||||||
|
seed ^= int(time.time())
|
||||||
|
|
||||||
|
args = EasyDict(arch=FLAGS.arch,
|
||||||
|
lr=FLAGS.lr,
|
||||||
|
batch=FLAGS.batch,
|
||||||
|
weight_decay=FLAGS.weight_decay,
|
||||||
|
augment=FLAGS.augment,
|
||||||
|
seed=seed)
|
||||||
|
|
||||||
|
|
||||||
|
if FLAGS.tunename:
|
||||||
|
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
|
||||||
|
elif FLAGS.expid is not None:
|
||||||
|
logdir = "experiment-%d_%d"%(FLAGS.expid,FLAGS.num_experiments)
|
||||||
|
else:
|
||||||
|
logdir = "experiment-"+str(seed)
|
||||||
|
logdir = os.path.join(FLAGS.logdir, logdir)
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
|
||||||
|
print(f"run {FLAGS.expid} already completed.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if os.path.exists(logdir):
|
||||||
|
print(f"deleting run {FLAGS.expid} that did not complete.")
|
||||||
|
shutil.rmtree(logdir)
|
||||||
|
|
||||||
|
print(f"starting run {FLAGS.expid}.")
|
||||||
|
if not os.path.exists(logdir):
|
||||||
|
os.makedirs(logdir)
|
||||||
|
|
||||||
|
train, test, xs, ys, keep, nclass = get_data(seed)
|
||||||
|
|
||||||
|
# Define the network and train_it
|
||||||
|
tm = MemModule(network(FLAGS.arch), nclass=nclass,
|
||||||
|
mnist=FLAGS.dataset == 'mnist',
|
||||||
|
epochs=FLAGS.epochs,
|
||||||
|
expid=FLAGS.expid,
|
||||||
|
num_experiments=FLAGS.num_experiments,
|
||||||
|
pkeep=FLAGS.pkeep,
|
||||||
|
save_steps=FLAGS.save_steps,
|
||||||
|
only_subset=FLAGS.only_subset,
|
||||||
|
**args
|
||||||
|
)
|
||||||
|
|
||||||
|
r = {}
|
||||||
|
r.update(tm.params)
|
||||||
|
|
||||||
|
open(os.path.join(logdir,'hparams.json'),"w").write(json.dumps(tm.params))
|
||||||
|
np.save(os.path.join(logdir,'keep.npy'), keep)
|
||||||
|
|
||||||
|
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
|
||||||
|
save_steps=FLAGS.save_steps, patience=FLAGS.patience)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
|
||||||
|
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
|
||||||
|
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||||
|
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
|
||||||
|
flags.DEFINE_integer('batch', 256, 'Batch size')
|
||||||
|
flags.DEFINE_integer('epochs', 501, 'Training duration in number of epochs.')
|
||||||
|
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
|
||||||
|
flags.DEFINE_integer('seed', None, 'Training seed.')
|
||||||
|
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
|
||||||
|
flags.DEFINE_integer('expid', None, 'Experiment ID')
|
||||||
|
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
|
||||||
|
flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
|
||||||
|
flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.')
|
||||||
|
flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.')
|
||||||
|
flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
|
||||||
|
flags.DEFINE_integer('abort_after_epoch', None, 'stop trainin early at an epoch')
|
||||||
|
flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
|
||||||
|
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
|
||||||
|
flags.DEFINE_bool('tunename', False, 'Use tune name?')
|
||||||
|
app.run(main)
|
Loading…
Reference in a new issue