Add code to reproduce Membership Inference Attacks From First Principles

This commit is contained in:
Nicholas Carlini 2021-12-14 00:50:49 +00:00
parent 8850c23f67
commit 7e40ad9704
10 changed files with 1012 additions and 0 deletions

View file

@ -0,0 +1,114 @@
This directory contains code to reproduce our paper:
**"Membership Inference Attacks From First Principles"**
https://arxiv.org/abs/2112.03570
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
###INSTALLING
You will need to install fairly standard dependencies
`pip install scipy, sklearn, numpy, matplotlib`
and also some machine learning framework to train models. We train our models
with JAX + ObJAX so you will need to follow build instructions for that
https://github.com/google/objax
https://objax.readthedocs.io/en/latest/installation_setup.html
###RUNNING THE CODE
####1. Train the models
The first step in our attack is to train shadow models. As a baseline
that should give most of the gains in our attack, you should start by
training 16 shadow models with the command
> bash scripts/train_demo.sh
or if you have multiple GPUs on your machine and want to train these models
in parallel, then modify and run
> bash scripts/train_demo_multigpu.sh
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
will output a bunch of files under the directory exp/cifar10 with structure:
```
exp/cifar10/
- experiment_N_of_16
-- hparams.json
-- keep.npy
-- ckpt/
--- 0000000100.npz
-- tb/
```
####2. Perform inference
Once the models are trained, now it's necessary to perform inference and save
the output features for each training example for each model in the dataset.
> python3 inference.py --logdir=exp/cifar10/
This will add to the experiment directory a new set of files
```
exp/cifar10/
- experiment_N_of_16
-- logits/
--- 0000000100.npy
```
where this new file has shape (50000, 10) and stores the model's
output features for each example.
####3. Compute membership inference scores
Finally we take the output features and generate our logit-scaled membership inference
scores for each example for each model.
> python3 score.py exp/cifar10/
And this in turn generates a new directory
```
exp/cifar10/
- experiment_N_of_16
-- scores/
--- 0000000100.npy
```
with shape (50000,) storing just our scores.
###PLOTTING THE RESULTS
Finally we can generate pretty pictures, and run the plotting code
> python3 plot.py
which should give (something like) the following output
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
```
Attack Ours (online)
AUC 0.6676, Accuracy 0.6077, TPR@0.1%FPR of 0.0169
Attack Ours (online, fixed variance)
AUC 0.6856, Accuracy 0.6137, TPR@0.1%FPR of 0.0593
Attack Ours (offline)
AUC 0.5488, Accuracy 0.5500, TPR@0.1%FPR of 0.0130
Attack Ours (offline, fixed variance)
AUC 0.5549, Accuracy 0.5537, TPR@0.1%FPR of 0.0299
Attack Global threshold
AUC 0.5921, Accuracy 0.6044, TPR@0.1%FPR of 0.0009
```
where the global threshold attack is the baseline, and our online,
online-with-fixed-variance, offline, and offline-with-fixed-variance
attack variants are the four other curves. Note that because we only
train a few models, the fixed variance variants perform best.

View file

@ -0,0 +1,95 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, List
import numpy as np
import tensorflow as tf
def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
features = tf.io.parse_single_example(serialized_example,
features={'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)})
image = tf.image.decode_image(features['image']).set_shape(image_shape)
image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
return dict(image=image, label=features['label'])
class DataSet:
"""Wrapper for tf.data.Dataset to permit extensions."""
def __init__(self, data: tf.data.Dataset,
image_shape: Tuple[int, int, int],
augment_fn: Optional[Callable] = None,
parse_fn: Optional[Callable] = record_parse):
self.data = data
self.parse_fn = parse_fn
self.augment_fn = augment_fn
self.image_shape = image_shape
@classmethod
def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
augment_fn=augment_fn, parse_fn=None)
@classmethod
def from_files(cls, filenames: List[str],
image_shape: Tuple[int, int, int],
augment_fn: Optional[Callable],
parse_fn: Optional[Callable] = record_parse):
filenames_in = filenames
filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
if not filenames:
raise ValueError('Empty dataset, files not found:', filenames_in)
return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)
@classmethod
def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
augment_fn: Optional[Callable] = None):
return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
image_shape, augment_fn=augment_fn, parse_fn=None)
def __iter__(self):
return iter(self.data)
def __getattr__(self, item):
if item in self.__dict__:
return self.__dict__[item]
def call_and_update(*args, **kwargs):
v = getattr(self.__dict__['data'], item)(*args, **kwargs)
if isinstance(v, tf.data.Dataset):
return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
return v
return call_and_update
def augment(self, para_augment: int = 4):
if self.augment_fn:
return self.map(self.augment_fn, para_augment)
return self
def nchw(self):
return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))
def one_hot(self, nclass: int):
return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))
def parse(self, para_parse: int = 2):
if not self.parse_fn:
return self
if self.image_shape:
return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
return self.map(self.parse_fn, para_parse)

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

View file

@ -0,0 +1,150 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Callable
import json
import re
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import pickle
from functools import partial
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet
from dataset import DataSet
from train import MemModule, network
from collections import defaultdict
FLAGS = flags.FLAGS
def main(argv):
"""
Perform inference of the saved model in order to generate the
output logits, using a particular set of augmentations.
"""
del argv
tf.config.experimental.set_visible_devices([], "GPU")
def load(arch):
return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10,
mnist=FLAGS.dataset == 'mnist',
arch=arch,
lr=.1,
batch=0,
epochs=0,
weight_decay=0)
def cache_load(arch):
thing = []
def fn():
if len(thing) == 0:
thing.append(load(arch))
return thing[0]
return fn
xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
outs = []
for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]:
aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy()
for dx in range(0, 2*shift+1, stride):
for dy in range(0, 2*shift+1, stride):
this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))
logits = model.model(this_x, training=True)
outs.append(logits)
print(np.array(outs).shape)
return np.array(outs).transpose((1, 0, 2))
N = 5000
def features(model, xbatch, ybatch):
return get_loss(model, xbatch, ybatch,
shift=0, reflect=True, stride=1)
for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
if re.search(FLAGS.regex, path) is None:
print("Skipping from regex")
continue
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
arch = hparams['arch']
model = cache_load(arch)()
logdir = os.path.join(FLAGS.logdir, path)
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
max_epoch, last_ckpt = checkpoint.restore(model.vars())
if max_epoch == 0: continue
if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")):
os.mkdir(os.path.join(FLAGS.logdir, path, "logits"))
if FLAGS.from_epoch is not None:
first = FLAGS.from_epoch
else:
first = max_epoch-1
for epoch in range(first,max_epoch+1):
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
# no checkpoint saved here
continue
if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
print("Skipping already generated file", epoch)
continue
try:
start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
except:
print("Fail to load", epoch)
continue
stats = []
for i in range(0,len(xs_all),N):
stats.extend(features(model, xs_all[i:i+N],
ys_all[i:i+N]))
# This will be shape N, augs, nclass
np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch),
np.array(stats)[:,None,:,:])
if __name__ == '__main__':
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
flags.DEFINE_bool('random_labels', False, 'use random labels.')
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
flags.DEFINE_integer('modulus', 8, 'modulus.')
app.run(main)

View file

View file

@ -0,0 +1,224 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import scipy.stats
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve
import functools
# Look at me being proactive!
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
def sweep(score, x):
"""
Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
"""
fpr, tpr, _ = roc_curve(x, -score)
acc = np.max(1-(fpr+(1-tpr))/2)
return fpr, tpr, auc(fpr, tpr), acc
def load_data(p):
"""
Load our saved scores and then put them into a big matrix.
"""
global scores, keep
scores = []
keep = []
for root,ds,_ in os.walk(p):
for f in ds:
if not f.startswith("experiment"): continue
if not os.path.exists(os.path.join(root,f,"scores")): continue
last_epoch = sorted(os.listdir(os.path.join(root,f,"scores")))
if len(last_epoch) == 0: continue
scores.append(np.load(os.path.join(root,f,"scores",last_epoch[-1])))
keep.append(np.load(os.path.join(root,f,"keep.npy")))
scores = np.array(scores)
keep = np.array(keep)[:,:scores.shape[1]]
return scores, keep
def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
fix_variance=False):
"""
Fit a two predictive models using keep and scores in order to predict
if the examples in check_scores were training data or not, using the
ground truth answer from check_keep.
"""
dat_in = []
dat_out = []
for j in range(scores.shape[1]):
dat_in.append(scores[keep[:,j],j,:])
dat_out.append(scores[~keep[:,j],j,:])
in_size = min(min(map(len,dat_in)), in_size)
out_size = min(min(map(len,dat_out)), out_size)
dat_in = np.array([x[:in_size] for x in dat_in])
dat_out = np.array([x[:out_size] for x in dat_out])
mean_in = np.median(dat_in, 1)
mean_out = np.median(dat_out, 1)
if fix_variance:
std_in = np.std(dat_in)
std_out = np.std(dat_in)
else:
std_in = np.std(dat_in, 1)
std_out = np.std(dat_out, 1)
prediction = []
answers = []
for ans, sc in zip(check_keep, check_scores):
pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in+1e-30)
pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
score = pr_in-pr_out
prediction.extend(score.mean(1))
answers.extend(ans)
return prediction, answers
def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
fix_variance=False):
"""
Fit a single predictive model using keep and scores in order to predict
if the examples in check_scores were training data or not, using the
ground truth answer from check_keep.
"""
dat_in = []
dat_out = []
for j in range(scores.shape[1]):
dat_in.append(scores[keep[:, j], j, :])
dat_out.append(scores[~keep[:, j], j, :])
out_size = min(min(map(len,dat_out)), out_size)
dat_out = np.array([x[:out_size] for x in dat_out])
mean_out = np.median(dat_out, 1)
if fix_variance:
std_out = np.std(dat_out)
else:
std_out = np.std(dat_out, 1)
prediction = []
answers = []
for ans, sc in zip(check_keep, check_scores):
score = scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
prediction.extend(score.mean(1))
answers.extend(ans)
return prediction, answers
def generate_global(keep, scores, check_keep, check_scores):
"""
Use a simple global threshold sweep to predict if the examples in
check_scores were training data or not, using the ground truth answer from
check_keep.
"""
prediction = []
answers = []
for ans, sc in zip(check_keep, check_scores):
prediction.extend(-sc.mean(1))
answers.extend(ans)
return prediction, answers
def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
"""
Generate the ROC curves by using ntest models as test models and the rest to train.
"""
prediction, answers = fn(keep[:-ntest],
scores[:-ntest],
keep[-ntest:],
scores[-ntest:])
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
low = tpr[np.where(fpr<.001)[0][-1]]
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
metric_text = ''
if metric == 'auc':
metric_text = 'auc=%.3f'%auc
elif metric == 'acc':
metric_text = 'acc=%.3f'%acc
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
return (acc,auc)
def fig_fpr_tpr():
plt.figure(figsize=(4,3))
do_plot(generate_ours,
keep, scores, 1,
"Ours (online)\n",
metric='auc'
)
do_plot(functools.partial(generate_ours, fix_variance=True),
keep, scores, 1,
"Ours (online, fixed variance)\n",
metric='auc'
)
do_plot(functools.partial(generate_ours_offline),
keep, scores, 1,
"Ours (offline)\n",
metric='auc'
)
do_plot(functools.partial(generate_ours_offline, fix_variance=True),
keep, scores, 1,
"Ours (offline, fixed variance)\n",
metric='auc'
)
do_plot(generate_global,
keep, scores, 1,
"Global threshold\n",
metric='auc'
)
plt.semilogx()
plt.semilogy()
plt.xlim(1e-5,1)
plt.ylim(1e-5,1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.plot([0, 1], [0, 1], ls='--', color='gray')
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
plt.legend(fontsize=8)
plt.savefig("/tmp/fprtpr.png")
plt.show()
load_data("exp/cifar10/")
fig_fpr_tpr()

View file

@ -0,0 +1,66 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import os
import multiprocessing as mp
def load_one(base):
"""
This loads a logits and converts it to a scored prediction.
"""
root = os.path.join(logdir,base,'logits')
if not os.path.exists(root): return None
if not os.path.exists(os.path.join(logdir,base,'scores')):
os.mkdir(os.path.join(logdir,base,'scores'))
for f in os.listdir(root):
try:
opredictions = np.load(os.path.join(root,f))
except:
print("Fail")
continue
## Be exceptionally careful.
## Numerically stable everything, as described in the paper.
predictions = opredictions - np.max(opredictions, axis=3, keepdims=True)
predictions = np.array(np.exp(predictions), dtype=np.float64)
predictions = predictions/np.sum(predictions,axis=3,keepdims=True)
COUNT = predictions.shape[0]
# x num_examples x num_augmentations x logits
y_true = predictions[np.arange(COUNT),:,:,labels[:COUNT]]
print(y_true.shape)
print('mean acc',np.mean(predictions[:,0,0,:].argmax(1)==labels[:COUNT]))
predictions[np.arange(COUNT),:,:,labels[:COUNT]] = 0
y_wrong = np.sum(predictions, axis=3)
logit = (np.log(y_true.mean((1))+1e-45) - np.log(y_wrong.mean((1))+1e-45))
np.save(os.path.join(logdir, base, 'scores', f), logit)
def load_stats():
with mp.Pool(8) as p:
p.map(load_one, [x for x in os.listdir(logdir) if 'exp' in x])
logdir = sys.argv[1]
labels = np.load(os.path.join(logdir,"y_train.npy"))
load_stats()

View file

@ -0,0 +1,16 @@
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15

View file

@ -0,0 +1,18 @@
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
wait;
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
wait;

View file

@ -0,0 +1,329 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import shutil
from typing import Callable
import json
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet, dnnet
from dataset import DataSet
FLAGS = flags.FLAGS
def augment(x, shift: int, mirror=True):
"""
Augmentation function used in training the model.
"""
y = x['image']
if mirror:
y = tf.image.random_flip_left_right(y)
y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT')
y = tf.image.random_crop(y, tf.shape(x['image']))
return dict(image=y, label=x['label'])
class TrainLoop(objax.Module):
"""
Training loop for general machine learning models.
Based on the training loop from the objax CIFAR10 example code.
"""
predict: Callable
train_op: Callable
def __init__(self, nclass: int, **kwargs):
self.nclass = nclass
self.params = EasyDict(kwargs)
def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
for k, v in kv.items():
if jn.isnan(v):
raise ValueError('NaN, try reducing learning rate', k)
if summary is not None:
summary.scalar(k, float(v))
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
"""
Completely standard training. Nothing interesting to see here.
"""
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
start_epoch, last_ckpt = checkpoint.restore(self.vars())
train_iter = iter(train)
progress = np.zeros(jax.local_device_count(), 'f') # for multi-GPU
best_acc = 0
best_acc_epoch = -1
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
for epoch in range(start_epoch, num_train_epochs):
# Train
summary = Summary()
loop = range(0, train_size, self.params.batch)
for step in loop:
progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
self.train_step(summary, next(train_iter), progress)
# Eval
accuracy, total = 0, 0
if epoch%FLAGS.eval_steps == 0 and test is not None:
for data in test:
total += data['image'].shape[0]
preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
accuracy += (preds == data['label'].numpy()).sum()
accuracy /= total
summary.scalar('eval/accuracy', 100 * accuracy)
tensorboard.write(summary, step=(epoch + 1) * train_size)
print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
summary['eval/accuracy']()))
if summary['eval/accuracy']() > best_acc:
best_acc = summary['eval/accuracy']()
best_acc_epoch = epoch
elif patience is not None and epoch > best_acc_epoch + patience:
print("early stopping!")
checkpoint.save(self.vars(), epoch + 1)
return
else:
print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))
if epoch%save_steps == save_steps-1:
checkpoint.save(self.vars(), epoch + 1)
# We inherit from the training loop and define predict and train_op.
class MemModule(TrainLoop):
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
"""
Completely standard training. Nothing interesting to see here.
"""
super().__init__(nclass, **kwargs)
self.model = model(1 if mnist else 3, nclass)
self.opt = objax.optimizer.Momentum(self.model.vars())
self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)
@objax.Function.with_vars(self.model.vars())
def loss(x, label):
logit = self.model(x, training=True)
loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}
gv = objax.GradValues(loss, self.model.vars())
self.gv = gv
@objax.Function.with_vars(self.vars())
def train_op(progress, x, y):
g, v = gv(x, y)
lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
lr = lr * jn.clip(progress*100,0,1)
self.opt(lr, g)
self.model_ema.update_ema()
return {'monitors/lr': lr, **v[1]}
self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
self.train_op = objax.Jit(train_op)
def network(arch: str):
if arch == 'cnn32-3-max':
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
pooling=objax.functional.max_pool_2d)
elif arch == 'cnn32-3-mean':
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
pooling=objax.functional.average_pool_2d)
elif arch == 'cnn64-3-max':
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
pooling=objax.functional.max_pool_2d)
elif arch == 'cnn64-3-mean':
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
pooling=objax.functional.average_pool_2d)
elif arch == 'wrn28-1':
return functools.partial(wide_resnet.WideResNet, depth=28, width=1)
elif arch == 'wrn28-2':
return functools.partial(wide_resnet.WideResNet, depth=28, width=2)
elif arch == 'wrn28-10':
return functools.partial(wide_resnet.WideResNet, depth=28, width=10)
raise ValueError('Architecture not recognized', arch)
def get_data(seed):
"""
This is the function to generate subsets of the data for training models.
First, we get the training dataset either from the numpy cache
or otherwise we load it from tensorflow datasets.
Then, we compute the subset. This works in one of two ways.
1. If we have a seed, then we just randomly choose examples based on
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
2. Otherwise, if we have an experiment ID, then we do something fancier.
If we run each experiment independently then even after a lot of trials
there will still probably be some examples that were always included
or always excluded. So instead, with experiment IDs, we guarantee that
after FLAGS.num_experiments are done, each example is seen exactly half
of the time in train, and half of the time not in train.
"""
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
else:
print("First time, creating dataset")
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']
inputs = (inputs/127.5)-1
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
nclass = np.max(labels)+1
np.random.seed(seed)
if FLAGS.num_experiments is not None:
np.random.seed(0)
keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size))
order = keep.argsort(0)
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
keep = np.array(keep[FLAGS.expid], dtype=bool)
else:
keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= FLAGS.pkeep
if FLAGS.only_subset is not None:
keep[FLAGS.only_subset:] = 0
xs = inputs[keep]
ys = labels[keep]
if FLAGS.augment == 'weak':
aug = lambda x: augment(x, 4)
elif FLAGS.augment == 'mirror':
aug = lambda x: augment(x, 0)
elif FLAGS.augment == 'none':
aug = lambda x: augment(x, 0, mirror=False)
else:
raise
train = DataSet.from_arrays(xs, ys,
augment_fn=aug)
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch)
train = train.nchw().one_hot(nclass).prefetch(16)
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16)
return train, test, xs, ys, keep, nclass
def main(argv):
del argv
tf.config.experimental.set_visible_devices([], "GPU")
seed = FLAGS.seed
if seed is None:
import time
seed = np.random.randint(0, 1000000000)
seed ^= int(time.time())
args = EasyDict(arch=FLAGS.arch,
lr=FLAGS.lr,
batch=FLAGS.batch,
weight_decay=FLAGS.weight_decay,
augment=FLAGS.augment,
seed=seed)
if FLAGS.tunename:
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
elif FLAGS.expid is not None:
logdir = "experiment-%d_%d"%(FLAGS.expid,FLAGS.num_experiments)
else:
logdir = "experiment-"+str(seed)
logdir = os.path.join(FLAGS.logdir, logdir)
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
print(f"run {FLAGS.expid} already completed.")
return
else:
if os.path.exists(logdir):
print(f"deleting run {FLAGS.expid} that did not complete.")
shutil.rmtree(logdir)
print(f"starting run {FLAGS.expid}.")
if not os.path.exists(logdir):
os.makedirs(logdir)
train, test, xs, ys, keep, nclass = get_data(seed)
# Define the network and train_it
tm = MemModule(network(FLAGS.arch), nclass=nclass,
mnist=FLAGS.dataset == 'mnist',
epochs=FLAGS.epochs,
expid=FLAGS.expid,
num_experiments=FLAGS.num_experiments,
pkeep=FLAGS.pkeep,
save_steps=FLAGS.save_steps,
only_subset=FLAGS.only_subset,
**args
)
r = {}
r.update(tm.params)
open(os.path.join(logdir,'hparams.json'),"w").write(json.dumps(tm.params))
np.save(os.path.join(logdir,'keep.npy'), keep)
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
save_steps=FLAGS.save_steps, patience=FLAGS.patience)
if __name__ == '__main__':
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
flags.DEFINE_integer('batch', 256, 'Batch size')
flags.DEFINE_integer('epochs', 501, 'Training duration in number of epochs.')
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_integer('seed', None, 'Training seed.')
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
flags.DEFINE_integer('expid', None, 'Experiment ID')
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.')
flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.')
flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
flags.DEFINE_integer('abort_after_epoch', None, 'stop trainin early at an epoch')
flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
flags.DEFINE_bool('tunename', False, 'Use tune name?')
app.run(main)