Add code to reproduce Membership Inference Attacks From First Principles
This commit is contained in:
parent
8850c23f67
commit
7e40ad9704
10 changed files with 1012 additions and 0 deletions
114
research/mi_lira_2021/README.md
Normal file
114
research/mi_lira_2021/README.md
Normal file
|
@ -0,0 +1,114 @@
|
|||
This directory contains code to reproduce our paper:
|
||||
|
||||
**"Membership Inference Attacks From First Principles"**
|
||||
https://arxiv.org/abs/2112.03570
|
||||
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
|
||||
|
||||
|
||||
###INSTALLING
|
||||
|
||||
You will need to install fairly standard dependencies
|
||||
|
||||
`pip install scipy, sklearn, numpy, matplotlib`
|
||||
|
||||
and also some machine learning framework to train models. We train our models
|
||||
with JAX + ObJAX so you will need to follow build instructions for that
|
||||
https://github.com/google/objax
|
||||
https://objax.readthedocs.io/en/latest/installation_setup.html
|
||||
|
||||
|
||||
###RUNNING THE CODE
|
||||
|
||||
####1. Train the models
|
||||
|
||||
The first step in our attack is to train shadow models. As a baseline
|
||||
that should give most of the gains in our attack, you should start by
|
||||
training 16 shadow models with the command
|
||||
|
||||
> bash scripts/train_demo.sh
|
||||
|
||||
or if you have multiple GPUs on your machine and want to train these models
|
||||
in parallel, then modify and run
|
||||
|
||||
> bash scripts/train_demo_multigpu.sh
|
||||
|
||||
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
|
||||
will output a bunch of files under the directory exp/cifar10 with structure:
|
||||
|
||||
```
|
||||
exp/cifar10/
|
||||
- experiment_N_of_16
|
||||
-- hparams.json
|
||||
-- keep.npy
|
||||
-- ckpt/
|
||||
--- 0000000100.npz
|
||||
-- tb/
|
||||
```
|
||||
|
||||
####2. Perform inference
|
||||
|
||||
Once the models are trained, now it's necessary to perform inference and save
|
||||
the output features for each training example for each model in the dataset.
|
||||
|
||||
> python3 inference.py --logdir=exp/cifar10/
|
||||
|
||||
This will add to the experiment directory a new set of files
|
||||
|
||||
```
|
||||
exp/cifar10/
|
||||
- experiment_N_of_16
|
||||
-- logits/
|
||||
--- 0000000100.npy
|
||||
```
|
||||
|
||||
where this new file has shape (50000, 10) and stores the model's
|
||||
output features for each example.
|
||||
|
||||
|
||||
####3. Compute membership inference scores
|
||||
|
||||
Finally we take the output features and generate our logit-scaled membership inference
|
||||
scores for each example for each model.
|
||||
|
||||
> python3 score.py exp/cifar10/
|
||||
|
||||
And this in turn generates a new directory
|
||||
|
||||
```
|
||||
exp/cifar10/
|
||||
- experiment_N_of_16
|
||||
-- scores/
|
||||
--- 0000000100.npy
|
||||
```
|
||||
|
||||
with shape (50000,) storing just our scores.
|
||||
|
||||
|
||||
###PLOTTING THE RESULTS
|
||||
|
||||
Finally we can generate pretty pictures, and run the plotting code
|
||||
|
||||
> python3 plot.py
|
||||
|
||||
which should give (something like) the following output
|
||||
|
||||
|
||||
![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")
|
||||
|
||||
```
|
||||
Attack Ours (online)
|
||||
AUC 0.6676, Accuracy 0.6077, TPR@0.1%FPR of 0.0169
|
||||
Attack Ours (online, fixed variance)
|
||||
AUC 0.6856, Accuracy 0.6137, TPR@0.1%FPR of 0.0593
|
||||
Attack Ours (offline)
|
||||
AUC 0.5488, Accuracy 0.5500, TPR@0.1%FPR of 0.0130
|
||||
Attack Ours (offline, fixed variance)
|
||||
AUC 0.5549, Accuracy 0.5537, TPR@0.1%FPR of 0.0299
|
||||
Attack Global threshold
|
||||
AUC 0.5921, Accuracy 0.6044, TPR@0.1%FPR of 0.0009
|
||||
```
|
||||
|
||||
where the global threshold attack is the baseline, and our online,
|
||||
online-with-fixed-variance, offline, and offline-with-fixed-variance
|
||||
attack variants are the four other curves. Note that because we only
|
||||
train a few models, the fixed variance variants perform best.
|
95
research/mi_lira_2021/dataset.py
Normal file
95
research/mi_lira_2021/dataset.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Optional, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
|
||||
features = tf.io.parse_single_example(serialized_example,
|
||||
features={'image': tf.io.FixedLenFeature([], tf.string),
|
||||
'label': tf.io.FixedLenFeature([], tf.int64)})
|
||||
image = tf.image.decode_image(features['image']).set_shape(image_shape)
|
||||
image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
|
||||
return dict(image=image, label=features['label'])
|
||||
|
||||
|
||||
class DataSet:
|
||||
"""Wrapper for tf.data.Dataset to permit extensions."""
|
||||
|
||||
def __init__(self, data: tf.data.Dataset,
|
||||
image_shape: Tuple[int, int, int],
|
||||
augment_fn: Optional[Callable] = None,
|
||||
parse_fn: Optional[Callable] = record_parse):
|
||||
self.data = data
|
||||
self.parse_fn = parse_fn
|
||||
self.augment_fn = augment_fn
|
||||
self.image_shape = image_shape
|
||||
|
||||
@classmethod
|
||||
def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
|
||||
return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
|
||||
augment_fn=augment_fn, parse_fn=None)
|
||||
|
||||
@classmethod
|
||||
def from_files(cls, filenames: List[str],
|
||||
image_shape: Tuple[int, int, int],
|
||||
augment_fn: Optional[Callable],
|
||||
parse_fn: Optional[Callable] = record_parse):
|
||||
filenames_in = filenames
|
||||
filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
|
||||
if not filenames:
|
||||
raise ValueError('Empty dataset, files not found:', filenames_in)
|
||||
return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)
|
||||
|
||||
@classmethod
|
||||
def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
|
||||
augment_fn: Optional[Callable] = None):
|
||||
return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
|
||||
image_shape, augment_fn=augment_fn, parse_fn=None)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in self.__dict__:
|
||||
return self.__dict__[item]
|
||||
|
||||
def call_and_update(*args, **kwargs):
|
||||
v = getattr(self.__dict__['data'], item)(*args, **kwargs)
|
||||
if isinstance(v, tf.data.Dataset):
|
||||
return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
|
||||
return v
|
||||
|
||||
return call_and_update
|
||||
|
||||
def augment(self, para_augment: int = 4):
|
||||
if self.augment_fn:
|
||||
return self.map(self.augment_fn, para_augment)
|
||||
return self
|
||||
|
||||
def nchw(self):
|
||||
return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))
|
||||
|
||||
def one_hot(self, nclass: int):
|
||||
return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))
|
||||
|
||||
def parse(self, para_parse: int = 2):
|
||||
if not self.parse_fn:
|
||||
return self
|
||||
if self.image_shape:
|
||||
return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
|
||||
return self.map(self.parse_fn, para_parse)
|
BIN
research/mi_lira_2021/fprtpr.png
Normal file
BIN
research/mi_lira_2021/fprtpr.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 37 KiB |
150
research/mi_lira_2021/inference.py
Normal file
150
research/mi_lira_2021/inference.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable
|
||||
import json
|
||||
|
||||
import re
|
||||
import jax
|
||||
import jax.numpy as jn
|
||||
import numpy as np
|
||||
import tensorflow as tf # For data augmentation.
|
||||
import tensorflow_datasets as tfds
|
||||
from absl import app, flags
|
||||
from tqdm import tqdm, trange
|
||||
import pickle
|
||||
from functools import partial
|
||||
|
||||
import objax
|
||||
from objax.jaxboard import SummaryWriter, Summary
|
||||
from objax.util import EasyDict
|
||||
from objax.zoo import convnet, wide_resnet
|
||||
|
||||
from dataset import DataSet
|
||||
|
||||
from train import MemModule, network
|
||||
|
||||
from collections import defaultdict
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(argv):
|
||||
"""
|
||||
Perform inference of the saved model in order to generate the
|
||||
output logits, using a particular set of augmentations.
|
||||
"""
|
||||
del argv
|
||||
tf.config.experimental.set_visible_devices([], "GPU")
|
||||
|
||||
def load(arch):
|
||||
return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10,
|
||||
mnist=FLAGS.dataset == 'mnist',
|
||||
arch=arch,
|
||||
lr=.1,
|
||||
batch=0,
|
||||
epochs=0,
|
||||
weight_decay=0)
|
||||
|
||||
def cache_load(arch):
|
||||
thing = []
|
||||
def fn():
|
||||
if len(thing) == 0:
|
||||
thing.append(load(arch))
|
||||
return thing[0]
|
||||
return fn
|
||||
|
||||
xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
|
||||
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
|
||||
|
||||
|
||||
def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
|
||||
|
||||
outs = []
|
||||
for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]:
|
||||
aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy()
|
||||
for dx in range(0, 2*shift+1, stride):
|
||||
for dy in range(0, 2*shift+1, stride):
|
||||
this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))
|
||||
|
||||
logits = model.model(this_x, training=True)
|
||||
outs.append(logits)
|
||||
|
||||
print(np.array(outs).shape)
|
||||
return np.array(outs).transpose((1, 0, 2))
|
||||
|
||||
N = 5000
|
||||
|
||||
def features(model, xbatch, ybatch):
|
||||
return get_loss(model, xbatch, ybatch,
|
||||
shift=0, reflect=True, stride=1)
|
||||
|
||||
for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
|
||||
if re.search(FLAGS.regex, path) is None:
|
||||
print("Skipping from regex")
|
||||
continue
|
||||
|
||||
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
|
||||
arch = hparams['arch']
|
||||
model = cache_load(arch)()
|
||||
|
||||
logdir = os.path.join(FLAGS.logdir, path)
|
||||
|
||||
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
|
||||
max_epoch, last_ckpt = checkpoint.restore(model.vars())
|
||||
if max_epoch == 0: continue
|
||||
|
||||
if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")):
|
||||
os.mkdir(os.path.join(FLAGS.logdir, path, "logits"))
|
||||
if FLAGS.from_epoch is not None:
|
||||
first = FLAGS.from_epoch
|
||||
else:
|
||||
first = max_epoch-1
|
||||
|
||||
for epoch in range(first,max_epoch+1):
|
||||
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
|
||||
# no checkpoint saved here
|
||||
continue
|
||||
|
||||
if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
|
||||
print("Skipping already generated file", epoch)
|
||||
continue
|
||||
|
||||
try:
|
||||
start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
|
||||
except:
|
||||
print("Fail to load", epoch)
|
||||
continue
|
||||
|
||||
stats = []
|
||||
|
||||
for i in range(0,len(xs_all),N):
|
||||
stats.extend(features(model, xs_all[i:i+N],
|
||||
ys_all[i:i+N]))
|
||||
# This will be shape N, augs, nclass
|
||||
|
||||
np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch),
|
||||
np.array(stats)[:,None,:,:])
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
|
||||
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
|
||||
flags.DEFINE_bool('random_labels', False, 'use random labels.')
|
||||
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
|
||||
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
|
||||
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
|
||||
flags.DEFINE_integer('modulus', 8, 'modulus.')
|
||||
app.run(main)
|
0
research/mi_lira_2021/logs/.keep
Normal file
0
research/mi_lira_2021/logs/.keep
Normal file
224
research/mi_lira_2021/plot.py
Normal file
224
research/mi_lira_2021/plot.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import scipy.stats
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import auc, roc_curve
|
||||
import functools
|
||||
|
||||
# Look at me being proactive!
|
||||
import matplotlib
|
||||
matplotlib.rcParams['pdf.fonttype'] = 42
|
||||
matplotlib.rcParams['ps.fonttype'] = 42
|
||||
|
||||
|
||||
def sweep(score, x):
|
||||
"""
|
||||
Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
|
||||
"""
|
||||
fpr, tpr, _ = roc_curve(x, -score)
|
||||
acc = np.max(1-(fpr+(1-tpr))/2)
|
||||
return fpr, tpr, auc(fpr, tpr), acc
|
||||
|
||||
def load_data(p):
|
||||
"""
|
||||
Load our saved scores and then put them into a big matrix.
|
||||
"""
|
||||
global scores, keep
|
||||
scores = []
|
||||
keep = []
|
||||
|
||||
for root,ds,_ in os.walk(p):
|
||||
for f in ds:
|
||||
if not f.startswith("experiment"): continue
|
||||
if not os.path.exists(os.path.join(root,f,"scores")): continue
|
||||
last_epoch = sorted(os.listdir(os.path.join(root,f,"scores")))
|
||||
if len(last_epoch) == 0: continue
|
||||
scores.append(np.load(os.path.join(root,f,"scores",last_epoch[-1])))
|
||||
keep.append(np.load(os.path.join(root,f,"keep.npy")))
|
||||
|
||||
scores = np.array(scores)
|
||||
keep = np.array(keep)[:,:scores.shape[1]]
|
||||
|
||||
return scores, keep
|
||||
|
||||
def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
|
||||
fix_variance=False):
|
||||
"""
|
||||
Fit a two predictive models using keep and scores in order to predict
|
||||
if the examples in check_scores were training data or not, using the
|
||||
ground truth answer from check_keep.
|
||||
"""
|
||||
dat_in = []
|
||||
dat_out = []
|
||||
|
||||
for j in range(scores.shape[1]):
|
||||
dat_in.append(scores[keep[:,j],j,:])
|
||||
dat_out.append(scores[~keep[:,j],j,:])
|
||||
|
||||
in_size = min(min(map(len,dat_in)), in_size)
|
||||
out_size = min(min(map(len,dat_out)), out_size)
|
||||
|
||||
dat_in = np.array([x[:in_size] for x in dat_in])
|
||||
dat_out = np.array([x[:out_size] for x in dat_out])
|
||||
|
||||
mean_in = np.median(dat_in, 1)
|
||||
mean_out = np.median(dat_out, 1)
|
||||
|
||||
if fix_variance:
|
||||
std_in = np.std(dat_in)
|
||||
std_out = np.std(dat_in)
|
||||
else:
|
||||
std_in = np.std(dat_in, 1)
|
||||
std_out = np.std(dat_out, 1)
|
||||
|
||||
prediction = []
|
||||
answers = []
|
||||
for ans, sc in zip(check_keep, check_scores):
|
||||
pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in+1e-30)
|
||||
pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
|
||||
score = pr_in-pr_out
|
||||
|
||||
prediction.extend(score.mean(1))
|
||||
answers.extend(ans)
|
||||
|
||||
return prediction, answers
|
||||
|
||||
def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
|
||||
fix_variance=False):
|
||||
"""
|
||||
Fit a single predictive model using keep and scores in order to predict
|
||||
if the examples in check_scores were training data or not, using the
|
||||
ground truth answer from check_keep.
|
||||
"""
|
||||
dat_in = []
|
||||
dat_out = []
|
||||
|
||||
for j in range(scores.shape[1]):
|
||||
dat_in.append(scores[keep[:, j], j, :])
|
||||
dat_out.append(scores[~keep[:, j], j, :])
|
||||
|
||||
out_size = min(min(map(len,dat_out)), out_size)
|
||||
|
||||
dat_out = np.array([x[:out_size] for x in dat_out])
|
||||
|
||||
mean_out = np.median(dat_out, 1)
|
||||
|
||||
if fix_variance:
|
||||
std_out = np.std(dat_out)
|
||||
else:
|
||||
std_out = np.std(dat_out, 1)
|
||||
|
||||
prediction = []
|
||||
answers = []
|
||||
for ans, sc in zip(check_keep, check_scores):
|
||||
score = scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
|
||||
|
||||
prediction.extend(score.mean(1))
|
||||
answers.extend(ans)
|
||||
return prediction, answers
|
||||
|
||||
|
||||
def generate_global(keep, scores, check_keep, check_scores):
|
||||
"""
|
||||
Use a simple global threshold sweep to predict if the examples in
|
||||
check_scores were training data or not, using the ground truth answer from
|
||||
check_keep.
|
||||
"""
|
||||
prediction = []
|
||||
answers = []
|
||||
for ans, sc in zip(check_keep, check_scores):
|
||||
prediction.extend(-sc.mean(1))
|
||||
answers.extend(ans)
|
||||
|
||||
return prediction, answers
|
||||
|
||||
def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
|
||||
"""
|
||||
Generate the ROC curves by using ntest models as test models and the rest to train.
|
||||
"""
|
||||
|
||||
prediction, answers = fn(keep[:-ntest],
|
||||
scores[:-ntest],
|
||||
keep[-ntest:],
|
||||
scores[-ntest:])
|
||||
|
||||
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
|
||||
|
||||
low = tpr[np.where(fpr<.001)[0][-1]]
|
||||
|
||||
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
|
||||
|
||||
metric_text = ''
|
||||
if metric == 'auc':
|
||||
metric_text = 'auc=%.3f'%auc
|
||||
elif metric == 'acc':
|
||||
metric_text = 'acc=%.3f'%acc
|
||||
|
||||
plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
|
||||
return (acc,auc)
|
||||
|
||||
|
||||
def fig_fpr_tpr():
|
||||
|
||||
plt.figure(figsize=(4,3))
|
||||
|
||||
do_plot(generate_ours,
|
||||
keep, scores, 1,
|
||||
"Ours (online)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(functools.partial(generate_ours, fix_variance=True),
|
||||
keep, scores, 1,
|
||||
"Ours (online, fixed variance)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(functools.partial(generate_ours_offline),
|
||||
keep, scores, 1,
|
||||
"Ours (offline)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(functools.partial(generate_ours_offline, fix_variance=True),
|
||||
keep, scores, 1,
|
||||
"Ours (offline, fixed variance)\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
do_plot(generate_global,
|
||||
keep, scores, 1,
|
||||
"Global threshold\n",
|
||||
metric='auc'
|
||||
)
|
||||
|
||||
plt.semilogx()
|
||||
plt.semilogy()
|
||||
plt.xlim(1e-5,1)
|
||||
plt.ylim(1e-5,1)
|
||||
plt.xlabel("False Positive Rate")
|
||||
plt.ylabel("True Positive Rate")
|
||||
plt.plot([0, 1], [0, 1], ls='--', color='gray')
|
||||
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
|
||||
plt.legend(fontsize=8)
|
||||
plt.savefig("/tmp/fprtpr.png")
|
||||
plt.show()
|
||||
|
||||
|
||||
load_data("exp/cifar10/")
|
||||
fig_fpr_tpr()
|
66
research/mi_lira_2021/score.py
Normal file
66
research/mi_lira_2021/score.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
def load_one(base):
|
||||
"""
|
||||
This loads a logits and converts it to a scored prediction.
|
||||
"""
|
||||
root = os.path.join(logdir,base,'logits')
|
||||
if not os.path.exists(root): return None
|
||||
|
||||
if not os.path.exists(os.path.join(logdir,base,'scores')):
|
||||
os.mkdir(os.path.join(logdir,base,'scores'))
|
||||
|
||||
for f in os.listdir(root):
|
||||
try:
|
||||
opredictions = np.load(os.path.join(root,f))
|
||||
except:
|
||||
print("Fail")
|
||||
continue
|
||||
|
||||
## Be exceptionally careful.
|
||||
## Numerically stable everything, as described in the paper.
|
||||
predictions = opredictions - np.max(opredictions, axis=3, keepdims=True)
|
||||
predictions = np.array(np.exp(predictions), dtype=np.float64)
|
||||
predictions = predictions/np.sum(predictions,axis=3,keepdims=True)
|
||||
|
||||
COUNT = predictions.shape[0]
|
||||
# x num_examples x num_augmentations x logits
|
||||
y_true = predictions[np.arange(COUNT),:,:,labels[:COUNT]]
|
||||
print(y_true.shape)
|
||||
|
||||
print('mean acc',np.mean(predictions[:,0,0,:].argmax(1)==labels[:COUNT]))
|
||||
|
||||
predictions[np.arange(COUNT),:,:,labels[:COUNT]] = 0
|
||||
y_wrong = np.sum(predictions, axis=3)
|
||||
|
||||
logit = (np.log(y_true.mean((1))+1e-45) - np.log(y_wrong.mean((1))+1e-45))
|
||||
|
||||
np.save(os.path.join(logdir, base, 'scores', f), logit)
|
||||
|
||||
|
||||
def load_stats():
|
||||
with mp.Pool(8) as p:
|
||||
p.map(load_one, [x for x in os.listdir(logdir) if 'exp' in x])
|
||||
|
||||
|
||||
logdir = sys.argv[1]
|
||||
labels = np.load(os.path.join(logdir,"y_train.npy"))
|
||||
load_stats()
|
16
research/mi_lira_2021/scripts/train_demo.sh
Normal file
16
research/mi_lira_2021/scripts/train_demo.sh
Normal file
|
@ -0,0 +1,16 @@
|
|||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
|
18
research/mi_lira_2021/scripts/train_demo_multigpu.sh
Normal file
18
research/mi_lira_2021/scripts/train_demo_multigpu.sh
Normal file
|
@ -0,0 +1,18 @@
|
|||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||
wait;
|
329
research/mi_lira_2021/train.py
Normal file
329
research/mi_lira_2021/train.py
Normal file
|
@ -0,0 +1,329 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import shutil
|
||||
from typing import Callable
|
||||
import json
|
||||
|
||||
import jax
|
||||
import jax.numpy as jn
|
||||
import numpy as np
|
||||
import tensorflow as tf # For data augmentation.
|
||||
import tensorflow_datasets as tfds
|
||||
from absl import app, flags
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import objax
|
||||
from objax.jaxboard import SummaryWriter, Summary
|
||||
from objax.util import EasyDict
|
||||
from objax.zoo import convnet, wide_resnet, dnnet
|
||||
|
||||
from dataset import DataSet
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
def augment(x, shift: int, mirror=True):
|
||||
"""
|
||||
Augmentation function used in training the model.
|
||||
"""
|
||||
y = x['image']
|
||||
if mirror:
|
||||
y = tf.image.random_flip_left_right(y)
|
||||
y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT')
|
||||
y = tf.image.random_crop(y, tf.shape(x['image']))
|
||||
return dict(image=y, label=x['label'])
|
||||
|
||||
|
||||
class TrainLoop(objax.Module):
|
||||
"""
|
||||
Training loop for general machine learning models.
|
||||
Based on the training loop from the objax CIFAR10 example code.
|
||||
"""
|
||||
predict: Callable
|
||||
train_op: Callable
|
||||
|
||||
def __init__(self, nclass: int, **kwargs):
|
||||
self.nclass = nclass
|
||||
self.params = EasyDict(kwargs)
|
||||
|
||||
def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
|
||||
kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
|
||||
for k, v in kv.items():
|
||||
if jn.isnan(v):
|
||||
raise ValueError('NaN, try reducing learning rate', k)
|
||||
if summary is not None:
|
||||
summary.scalar(k, float(v))
|
||||
|
||||
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
|
||||
"""
|
||||
Completely standard training. Nothing interesting to see here.
|
||||
"""
|
||||
checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
|
||||
start_epoch, last_ckpt = checkpoint.restore(self.vars())
|
||||
train_iter = iter(train)
|
||||
progress = np.zeros(jax.local_device_count(), 'f') # for multi-GPU
|
||||
|
||||
best_acc = 0
|
||||
best_acc_epoch = -1
|
||||
|
||||
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
|
||||
for epoch in range(start_epoch, num_train_epochs):
|
||||
# Train
|
||||
summary = Summary()
|
||||
loop = range(0, train_size, self.params.batch)
|
||||
for step in loop:
|
||||
progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
|
||||
self.train_step(summary, next(train_iter), progress)
|
||||
|
||||
# Eval
|
||||
accuracy, total = 0, 0
|
||||
if epoch%FLAGS.eval_steps == 0 and test is not None:
|
||||
for data in test:
|
||||
total += data['image'].shape[0]
|
||||
preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
|
||||
accuracy += (preds == data['label'].numpy()).sum()
|
||||
accuracy /= total
|
||||
summary.scalar('eval/accuracy', 100 * accuracy)
|
||||
tensorboard.write(summary, step=(epoch + 1) * train_size)
|
||||
print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
|
||||
summary['eval/accuracy']()))
|
||||
|
||||
if summary['eval/accuracy']() > best_acc:
|
||||
best_acc = summary['eval/accuracy']()
|
||||
best_acc_epoch = epoch
|
||||
elif patience is not None and epoch > best_acc_epoch + patience:
|
||||
print("early stopping!")
|
||||
checkpoint.save(self.vars(), epoch + 1)
|
||||
return
|
||||
|
||||
else:
|
||||
print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))
|
||||
|
||||
if epoch%save_steps == save_steps-1:
|
||||
checkpoint.save(self.vars(), epoch + 1)
|
||||
|
||||
|
||||
# We inherit from the training loop and define predict and train_op.
|
||||
class MemModule(TrainLoop):
|
||||
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
|
||||
"""
|
||||
Completely standard training. Nothing interesting to see here.
|
||||
"""
|
||||
super().__init__(nclass, **kwargs)
|
||||
self.model = model(1 if mnist else 3, nclass)
|
||||
self.opt = objax.optimizer.Momentum(self.model.vars())
|
||||
self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)
|
||||
|
||||
@objax.Function.with_vars(self.model.vars())
|
||||
def loss(x, label):
|
||||
logit = self.model(x, training=True)
|
||||
loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
|
||||
loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
|
||||
return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}
|
||||
|
||||
gv = objax.GradValues(loss, self.model.vars())
|
||||
self.gv = gv
|
||||
|
||||
@objax.Function.with_vars(self.vars())
|
||||
def train_op(progress, x, y):
|
||||
g, v = gv(x, y)
|
||||
lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
|
||||
lr = lr * jn.clip(progress*100,0,1)
|
||||
self.opt(lr, g)
|
||||
self.model_ema.update_ema()
|
||||
return {'monitors/lr': lr, **v[1]}
|
||||
|
||||
self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
|
||||
|
||||
self.train_op = objax.Jit(train_op)
|
||||
|
||||
|
||||
def network(arch: str):
|
||||
if arch == 'cnn32-3-max':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
|
||||
pooling=objax.functional.max_pool_2d)
|
||||
elif arch == 'cnn32-3-mean':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
|
||||
pooling=objax.functional.average_pool_2d)
|
||||
elif arch == 'cnn64-3-max':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
|
||||
pooling=objax.functional.max_pool_2d)
|
||||
elif arch == 'cnn64-3-mean':
|
||||
return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
|
||||
pooling=objax.functional.average_pool_2d)
|
||||
elif arch == 'wrn28-1':
|
||||
return functools.partial(wide_resnet.WideResNet, depth=28, width=1)
|
||||
elif arch == 'wrn28-2':
|
||||
return functools.partial(wide_resnet.WideResNet, depth=28, width=2)
|
||||
elif arch == 'wrn28-10':
|
||||
return functools.partial(wide_resnet.WideResNet, depth=28, width=10)
|
||||
raise ValueError('Architecture not recognized', arch)
|
||||
|
||||
def get_data(seed):
|
||||
"""
|
||||
This is the function to generate subsets of the data for training models.
|
||||
|
||||
First, we get the training dataset either from the numpy cache
|
||||
or otherwise we load it from tensorflow datasets.
|
||||
|
||||
Then, we compute the subset. This works in one of two ways.
|
||||
|
||||
1. If we have a seed, then we just randomly choose examples based on
|
||||
a prng with that seed, keeping FLAGS.pkeep fraction of the data.
|
||||
|
||||
2. Otherwise, if we have an experiment ID, then we do something fancier.
|
||||
If we run each experiment independently then even after a lot of trials
|
||||
there will still probably be some examples that were always included
|
||||
or always excluded. So instead, with experiment IDs, we guarantee that
|
||||
after FLAGS.num_experiments are done, each example is seen exactly half
|
||||
of the time in train, and half of the time not in train.
|
||||
|
||||
"""
|
||||
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
|
||||
|
||||
if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
|
||||
inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
|
||||
labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
|
||||
else:
|
||||
print("First time, creating dataset")
|
||||
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
|
||||
inputs = data['train']['image']
|
||||
labels = data['train']['label']
|
||||
|
||||
inputs = (inputs/127.5)-1
|
||||
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
|
||||
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
|
||||
|
||||
nclass = np.max(labels)+1
|
||||
|
||||
np.random.seed(seed)
|
||||
if FLAGS.num_experiments is not None:
|
||||
np.random.seed(0)
|
||||
keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size))
|
||||
order = keep.argsort(0)
|
||||
keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
|
||||
keep = np.array(keep[FLAGS.expid], dtype=bool)
|
||||
else:
|
||||
keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= FLAGS.pkeep
|
||||
|
||||
if FLAGS.only_subset is not None:
|
||||
keep[FLAGS.only_subset:] = 0
|
||||
|
||||
xs = inputs[keep]
|
||||
ys = labels[keep]
|
||||
|
||||
if FLAGS.augment == 'weak':
|
||||
aug = lambda x: augment(x, 4)
|
||||
elif FLAGS.augment == 'mirror':
|
||||
aug = lambda x: augment(x, 0)
|
||||
elif FLAGS.augment == 'none':
|
||||
aug = lambda x: augment(x, 0, mirror=False)
|
||||
else:
|
||||
raise
|
||||
|
||||
train = DataSet.from_arrays(xs, ys,
|
||||
augment_fn=aug)
|
||||
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
|
||||
train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch)
|
||||
train = train.nchw().one_hot(nclass).prefetch(16)
|
||||
test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16)
|
||||
|
||||
return train, test, xs, ys, keep, nclass
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
tf.config.experimental.set_visible_devices([], "GPU")
|
||||
|
||||
seed = FLAGS.seed
|
||||
if seed is None:
|
||||
import time
|
||||
seed = np.random.randint(0, 1000000000)
|
||||
seed ^= int(time.time())
|
||||
|
||||
args = EasyDict(arch=FLAGS.arch,
|
||||
lr=FLAGS.lr,
|
||||
batch=FLAGS.batch,
|
||||
weight_decay=FLAGS.weight_decay,
|
||||
augment=FLAGS.augment,
|
||||
seed=seed)
|
||||
|
||||
|
||||
if FLAGS.tunename:
|
||||
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
|
||||
elif FLAGS.expid is not None:
|
||||
logdir = "experiment-%d_%d"%(FLAGS.expid,FLAGS.num_experiments)
|
||||
else:
|
||||
logdir = "experiment-"+str(seed)
|
||||
logdir = os.path.join(FLAGS.logdir, logdir)
|
||||
|
||||
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
|
||||
print(f"run {FLAGS.expid} already completed.")
|
||||
return
|
||||
else:
|
||||
if os.path.exists(logdir):
|
||||
print(f"deleting run {FLAGS.expid} that did not complete.")
|
||||
shutil.rmtree(logdir)
|
||||
|
||||
print(f"starting run {FLAGS.expid}.")
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
|
||||
train, test, xs, ys, keep, nclass = get_data(seed)
|
||||
|
||||
# Define the network and train_it
|
||||
tm = MemModule(network(FLAGS.arch), nclass=nclass,
|
||||
mnist=FLAGS.dataset == 'mnist',
|
||||
epochs=FLAGS.epochs,
|
||||
expid=FLAGS.expid,
|
||||
num_experiments=FLAGS.num_experiments,
|
||||
pkeep=FLAGS.pkeep,
|
||||
save_steps=FLAGS.save_steps,
|
||||
only_subset=FLAGS.only_subset,
|
||||
**args
|
||||
)
|
||||
|
||||
r = {}
|
||||
r.update(tm.params)
|
||||
|
||||
open(os.path.join(logdir,'hparams.json'),"w").write(json.dumps(tm.params))
|
||||
np.save(os.path.join(logdir,'keep.npy'), keep)
|
||||
|
||||
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
|
||||
save_steps=FLAGS.save_steps, patience=FLAGS.patience)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
|
||||
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
|
||||
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
|
||||
flags.DEFINE_integer('batch', 256, 'Batch size')
|
||||
flags.DEFINE_integer('epochs', 501, 'Training duration in number of epochs.')
|
||||
flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
|
||||
flags.DEFINE_integer('seed', None, 'Training seed.')
|
||||
flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
|
||||
flags.DEFINE_integer('expid', None, 'Experiment ID')
|
||||
flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
|
||||
flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
|
||||
flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.')
|
||||
flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.')
|
||||
flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
|
||||
flags.DEFINE_integer('abort_after_epoch', None, 'stop trainin early at an epoch')
|
||||
flags.DEFINE_integer('save_steps', 10, 'how often to get save model.')
|
||||
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
|
||||
flags.DEFINE_bool('tunename', False, 'Use tune name?')
|
||||
app.run(main)
|
Loading…
Reference in a new issue