diff --git a/research/mi_lira_2021/README.md b/research/mi_lira_2021/README.md new file mode 100644 index 0000000..fc287b0 --- /dev/null +++ b/research/mi_lira_2021/README.md @@ -0,0 +1,114 @@ +This directory contains code to reproduce our paper: + +**"Membership Inference Attacks From First Principles"** +https://arxiv.org/abs/2112.03570 +by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer. + + +### INSTALLING + +You will need to install fairly standard dependencies + +`pip install scipy scikit-learn numpy matplotlib` + +and also some machine learning framework to train models. We train our models +with JAX + ObJAX so you will need to follow build instructions for that +https://github.com/google/objax +https://objax.readthedocs.io/en/latest/installation_setup.html + + +### RUNNING THE CODE + +#### 1. Train the models + +The first step in our attack is to train shadow models. As a baseline +that should give most of the gains in our attack, you should start by +training 16 shadow models with the command + +> bash scripts/train_demo.sh + +or if you have multiple GPUs on your machine and want to train these models +in parallel, then modify and run + +> bash scripts/train_demo_multigpu.sh + +This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and +will output a bunch of files under the directory exp/cifar10 with structure: + +``` +exp/cifar10/ +- experiment_N_of_16 +-- hparams.json +-- keep.npy +-- ckpt/ +--- 0000000100.npz +-- tb/ +``` + +#### 2. Perform inference + +Once the models are trained, now it's necessary to perform inference and save +the output features for each training example for each model in the dataset. + +> python3 inference.py --logdir=exp/cifar10/ + +This will add to the experiment directory a new set of files + +``` +exp/cifar10/ +- experiment_N_of_16 +-- logits/ +--- 0000000100.npy +``` + +where this new file has shape (50000, 1, 2, 10), i.e. (examples, 1, augmentations, classes), and stores the model's +output features for each example. 
+ + +#### 3. Compute membership inference scores + +Finally we take the output features and generate our logit-scaled membership inference +scores for each example for each model. + +> python3 score.py exp/cifar10/ + +And this in turn generates a new directory + +``` +exp/cifar10/ +- experiment_N_of_16 +-- scores/ +--- 0000000100.npy +``` + +with shape (50000, 2) storing just our scores (one per example per augmentation). + + +### PLOTTING THE RESULTS + +Finally we can generate pretty pictures, and run the plotting code + +> python3 plot.py + +which should give (something like) the following output + + +![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve") + +``` +Attack Ours (online) + AUC 0.6676, Accuracy 0.6077, TPR@0.1%FPR of 0.0169 +Attack Ours (online, fixed variance) + AUC 0.6856, Accuracy 0.6137, TPR@0.1%FPR of 0.0593 +Attack Ours (offline) + AUC 0.5488, Accuracy 0.5500, TPR@0.1%FPR of 0.0130 +Attack Ours (offline, fixed variance) + AUC 0.5549, Accuracy 0.5537, TPR@0.1%FPR of 0.0299 +Attack Global threshold + AUC 0.5921, Accuracy 0.6044, TPR@0.1%FPR of 0.0009 +``` + +where the global threshold attack is the baseline, and our online, +online-with-fixed-variance, offline, and offline-with-fixed-variance +attack variants are the four other curves. Note that because we only +train a few models, the fixed variance variants perform best. diff --git a/research/mi_lira_2021/dataset.py b/research/mi_lira_2021/dataset.py new file mode 100644 index 0000000..fa2c2b0 --- /dev/null +++ b/research/mi_lira_2021/dataset.py @@ -0,0 +1,95 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Tuple, List + +import numpy as np +import tensorflow as tf + + +def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]): + features = tf.io.parse_single_example(serialized_example, + features={'image': tf.io.FixedLenFeature([], tf.string), + 'label': tf.io.FixedLenFeature([], tf.int64)}) + image = tf.ensure_shape(tf.image.decode_image(features['image']), image_shape)  # Tensor.set_shape() returns None, so its result must not be assigned + image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0 + return dict(image=image, label=features['label']) + + +class DataSet: + """Wrapper for tf.data.Dataset to permit extensions.""" + + def __init__(self, data: tf.data.Dataset, + image_shape: Tuple[int, int, int], + augment_fn: Optional[Callable] = None, + parse_fn: Optional[Callable] = record_parse): + self.data = data + self.parse_fn = parse_fn + self.augment_fn = augment_fn + self.image_shape = image_shape + + @classmethod + def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None): + return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:], + augment_fn=augment_fn, parse_fn=None) + + @classmethod + def from_files(cls, filenames: List[str], + image_shape: Tuple[int, int, int], + augment_fn: Optional[Callable], + parse_fn: Optional[Callable] = record_parse): + filenames_in = filenames + filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], [])) + if not filenames: + raise ValueError('Empty dataset, files not found:', filenames_in) + return cls(tf.data.TFRecordDataset(filenames), 
image_shape, augment_fn=augment_fn, parse_fn=parse_fn) + + @classmethod + def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int], + augment_fn: Optional[Callable] = None): + return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])), + image_shape, augment_fn=augment_fn, parse_fn=None) + + def __iter__(self): + return iter(self.data) + + def __getattr__(self, item): + if item in self.__dict__: + return self.__dict__[item] + + def call_and_update(*args, **kwargs): + v = getattr(self.__dict__['data'], item)(*args, **kwargs) + if isinstance(v, tf.data.Dataset): + return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn) + return v + + return call_and_update + + def augment(self, para_augment: int = 4): + if self.augment_fn: + return self.map(self.augment_fn, para_augment) + return self + + def nchw(self): + return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label'])) + + def one_hot(self, nclass: int): + return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass))) + + def parse(self, para_parse: int = 2): + if not self.parse_fn: + return self + if self.image_shape: + return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse) + return self.map(self.parse_fn, para_parse) diff --git a/research/mi_lira_2021/fprtpr.png b/research/mi_lira_2021/fprtpr.png new file mode 100644 index 0000000..8419ca1 Binary files /dev/null and b/research/mi_lira_2021/fprtpr.png differ diff --git a/research/mi_lira_2021/inference.py b/research/mi_lira_2021/inference.py new file mode 100644 index 0000000..11ad696 --- /dev/null +++ b/research/mi_lira_2021/inference.py @@ -0,0 +1,150 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +from typing import Callable +import json + +import re +import jax +import jax.numpy as jn +import numpy as np +import tensorflow as tf # For data augmentation. +import tensorflow_datasets as tfds +from absl import app, flags +from tqdm import tqdm, trange +import pickle +from functools import partial + +import objax +from objax.jaxboard import SummaryWriter, Summary +from objax.util import EasyDict +from objax.zoo import convnet, wide_resnet + +from dataset import DataSet + +from train import MemModule, network + +from collections import defaultdict +FLAGS = flags.FLAGS + + +def main(argv): + """ + Perform inference of the saved model in order to generate the + output logits, using a particular set of augmentations. 
+ """ + del argv + tf.config.experimental.set_visible_devices([], "GPU") + + def load(arch): + return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10, + mnist=FLAGS.dataset == 'mnist', + arch=arch, + lr=.1, + batch=0, + epochs=0, + weight_decay=0) + + def cache_load(arch): + thing = [] + def fn(): + if len(thing) == 0: + thing.append(load(arch)) + return thing[0] + return fn + + xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size] + ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size] + + + def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1): + + outs = [] + for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]: + aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy() + for dx in range(0, 2*shift+1, stride): + for dy in range(0, 2*shift+1, stride): + this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2)) + + logits = model.model(this_x, training=True) + outs.append(logits) + + print(np.array(outs).shape) + return np.array(outs).transpose((1, 0, 2)) + + N = 5000 + + def features(model, xbatch, ybatch): + return get_loss(model, xbatch, ybatch, + shift=0, reflect=True, stride=1) + + for path in sorted(os.listdir(os.path.join(FLAGS.logdir))): + if re.search(FLAGS.regex, path) is None: + print("Skipping from regex") + continue + + hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json"))) + arch = hparams['arch'] + model = cache_load(arch)() + + logdir = os.path.join(FLAGS.logdir, path) + + checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True) + max_epoch, last_ckpt = checkpoint.restore(model.vars()) + if max_epoch == 0: continue + + if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")): + os.mkdir(os.path.join(FLAGS.logdir, path, "logits")) + if FLAGS.from_epoch is not None: + first = FLAGS.from_epoch + else: + first = max_epoch-1 + + for epoch in range(first,max_epoch+1): + if not 
os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)): + # no checkpoint saved here + continue + + if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)): + print("Skipping already generated file", epoch) + continue + + try: + start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch) + except Exception:  # best-effort skip, but let KeyboardInterrupt/SystemExit propagate + print("Fail to load", epoch) + continue + + stats = [] + + for i in range(0,len(xs_all),N): + stats.extend(features(model, xs_all[i:i+N], + ys_all[i:i+N])) + # This will be shape N, augs, nclass + + np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch), + np.array(stats)[:,None,:,:]) + +if __name__ == '__main__': + flags.DEFINE_string('dataset', 'cifar10', 'Dataset.') + flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.') + flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching') + flags.DEFINE_bool('random_labels', False, 'use random labels.') + flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.') + flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.') + flags.DEFINE_integer('seed_mod', None, 'keep mod seed.') + flags.DEFINE_integer('modulus', 8, 'modulus.') + app.run(main) diff --git a/research/mi_lira_2021/logs/.keep b/research/mi_lira_2021/logs/.keep new file mode 100644 index 0000000..e69de29 diff --git a/research/mi_lira_2021/plot.py b/research/mi_lira_2021/plot.py new file mode 100644 index 0000000..435125c --- /dev/null +++ b/research/mi_lira_2021/plot.py @@ -0,0 +1,224 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import scipy.stats + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics import auc, roc_curve +import functools + +# Look at me being proactive! +import matplotlib +matplotlib.rcParams['pdf.fonttype'] = 42 +matplotlib.rcParams['ps.fonttype'] = 42 + + +def sweep(score, x): + """ + Compute a ROC curve and then return the FPR, TPR, AUC, and ACC. + """ + fpr, tpr, _ = roc_curve(x, -score) + acc = np.max(1-(fpr+(1-tpr))/2) + return fpr, tpr, auc(fpr, tpr), acc + +def load_data(p): + """ + Load our saved scores and then put them into a big matrix. + """ + global scores, keep + scores = [] + keep = [] + + for root,ds,_ in os.walk(p): + for f in ds: + if not f.startswith("experiment"): continue + if not os.path.exists(os.path.join(root,f,"scores")): continue + last_epoch = sorted(os.listdir(os.path.join(root,f,"scores"))) + if len(last_epoch) == 0: continue + scores.append(np.load(os.path.join(root,f,"scores",last_epoch[-1]))) + keep.append(np.load(os.path.join(root,f,"keep.npy"))) + + scores = np.array(scores) + keep = np.array(keep)[:,:scores.shape[1]] + + return scores, keep + +def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000, + fix_variance=False): + """ + Fit a two predictive models using keep and scores in order to predict + if the examples in check_scores were training data or not, using the + ground truth answer from check_keep. 
+ """ + dat_in = [] + dat_out = [] + + for j in range(scores.shape[1]): + dat_in.append(scores[keep[:,j],j,:]) + dat_out.append(scores[~keep[:,j],j,:]) + + in_size = min(min(map(len,dat_in)), in_size) + out_size = min(min(map(len,dat_out)), out_size) + + dat_in = np.array([x[:in_size] for x in dat_in]) + dat_out = np.array([x[:out_size] for x in dat_out]) + + mean_in = np.median(dat_in, 1) + mean_out = np.median(dat_out, 1) + + if fix_variance: + std_in = np.std(dat_in) + std_out = np.std(dat_out)  # global std of the OUT scores, matching the per-example branch below + else: + std_in = np.std(dat_in, 1) + std_out = np.std(dat_out, 1) + + prediction = [] + answers = [] + for ans, sc in zip(check_keep, check_scores): + pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in+1e-30) + pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30) + score = pr_in-pr_out + + prediction.extend(score.mean(1)) + answers.extend(ans) + + return prediction, answers + +def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000, + fix_variance=False): + """ + Fit a single predictive model using keep and scores in order to predict + if the examples in check_scores were training data or not, using the + ground truth answer from check_keep. 
+ """ + dat_in = [] + dat_out = [] + + for j in range(scores.shape[1]): + dat_in.append(scores[keep[:, j], j, :]) + dat_out.append(scores[~keep[:, j], j, :]) + + out_size = min(min(map(len,dat_out)), out_size) + + dat_out = np.array([x[:out_size] for x in dat_out]) + + mean_out = np.median(dat_out, 1) + + if fix_variance: + std_out = np.std(dat_out) + else: + std_out = np.std(dat_out, 1) + + prediction = [] + answers = [] + for ans, sc in zip(check_keep, check_scores): + score = scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30) + + prediction.extend(score.mean(1)) + answers.extend(ans) + return prediction, answers + + +def generate_global(keep, scores, check_keep, check_scores): + """ + Use a simple global threshold sweep to predict if the examples in + check_scores were training data or not, using the ground truth answer from + check_keep. + """ + prediction = [] + answers = [] + for ans, sc in zip(check_keep, check_scores): + prediction.extend(-sc.mean(1)) + answers.extend(ans) + + return prediction, answers + +def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs): + """ + Generate the ROC curves by using ntest models as test models and the rest to train. 
+ """ + + prediction, answers = fn(keep[:-ntest], + scores[:-ntest], + keep[-ntest:], + scores[-ntest:]) + + fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool)) + + low = tpr[np.where(fpr<.001)[0][-1]] + + print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low)) + + metric_text = '' + if metric == 'auc': + metric_text = 'auc=%.3f'%auc + elif metric == 'acc': + metric_text = 'acc=%.3f'%acc + + plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs) + return (acc,auc) + + +def fig_fpr_tpr(): + + plt.figure(figsize=(4,3)) + + do_plot(generate_ours, + keep, scores, 1, + "Ours (online)\n", + metric='auc' + ) + + do_plot(functools.partial(generate_ours, fix_variance=True), + keep, scores, 1, + "Ours (online, fixed variance)\n", + metric='auc' + ) + + do_plot(functools.partial(generate_ours_offline), + keep, scores, 1, + "Ours (offline)\n", + metric='auc' + ) + + do_plot(functools.partial(generate_ours_offline, fix_variance=True), + keep, scores, 1, + "Ours (offline, fixed variance)\n", + metric='auc' + ) + + do_plot(generate_global, + keep, scores, 1, + "Global threshold\n", + metric='auc' + ) + + plt.semilogx() + plt.semilogy() + plt.xlim(1e-5,1) + plt.ylim(1e-5,1) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.plot([0, 1], [0, 1], ls='--', color='gray') + plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96) + plt.legend(fontsize=8) + plt.savefig("/tmp/fprtpr.png") + plt.show() + + +load_data("exp/cifar10/") +fig_fpr_tpr() diff --git a/research/mi_lira_2021/score.py b/research/mi_lira_2021/score.py new file mode 100644 index 0000000..91aeaf4 --- /dev/null +++ b/research/mi_lira_2021/score.py @@ -0,0 +1,66 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import numpy as np +import os +import multiprocessing as mp + + +def load_one(base): + """ + Load each saved logits file of one experiment and convert it to a scored prediction. + """ + root = os.path.join(logdir,base,'logits') + if not os.path.exists(root): return None + + if not os.path.exists(os.path.join(logdir,base,'scores')): + os.mkdir(os.path.join(logdir,base,'scores')) + + for f in os.listdir(root): + try: + opredictions = np.load(os.path.join(root,f)) + except Exception:  # skip unreadable files, but let KeyboardInterrupt/SystemExit propagate + print("Fail") + continue + + ## Be exceptionally careful. + ## Numerically stable everything, as described in the paper. 
+ predictions = opredictions - np.max(opredictions, axis=3, keepdims=True) + predictions = np.array(np.exp(predictions), dtype=np.float64) + predictions = predictions/np.sum(predictions,axis=3,keepdims=True) + + COUNT = predictions.shape[0] + # x num_examples x num_augmentations x logits + y_true = predictions[np.arange(COUNT),:,:,labels[:COUNT]] + print(y_true.shape) + + print('mean acc',np.mean(predictions[:,0,0,:].argmax(1)==labels[:COUNT])) + + predictions[np.arange(COUNT),:,:,labels[:COUNT]] = 0 + y_wrong = np.sum(predictions, axis=3) + + logit = (np.log(y_true.mean((1))+1e-45) - np.log(y_wrong.mean((1))+1e-45)) + + np.save(os.path.join(logdir, base, 'scores', f), logit) + + +def load_stats(): + with mp.Pool(8) as p: + p.map(load_one, [x for x in os.listdir(logdir) if 'exp' in x]) + + +logdir = sys.argv[1] +labels = np.load(os.path.join(logdir,"y_train.npy")) +load_stats() diff --git a/research/mi_lira_2021/scripts/train_demo.sh b/research/mi_lira_2021/scripts/train_demo.sh new file mode 100644 index 0000000..06f8779 --- /dev/null +++ b/research/mi_lira_2021/scripts/train_demo.sh @@ -0,0 +1,16 @@ +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 diff --git a/research/mi_lira_2021/scripts/train_demo_multigpu.sh 
b/research/mi_lira_2021/scripts/train_demo_multigpu.sh new file mode 100644 index 0000000..6bd689d --- /dev/null +++ b/research/mi_lira_2021/scripts/train_demo_multigpu.sh @@ -0,0 +1,18 @@ +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 & +CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 & +CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 & +CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 & +CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 & +CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 & +CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 & +CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 & +wait; +CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 & +CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 & +CUDA_VISIBLE_DEVICES='2' python3 -u 
train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 & +CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 & +CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 & +CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 & +CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 & +CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 & +wait; diff --git a/research/mi_lira_2021/train.py b/research/mi_lira_2021/train.py new file mode 100644 index 0000000..19ff0e3 --- /dev/null +++ b/research/mi_lira_2021/train.py @@ -0,0 +1,329 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +import os +import shutil +from typing import Callable +import json + +import jax +import jax.numpy as jn +import numpy as np +import tensorflow as tf # For data augmentation. +import tensorflow_datasets as tfds +from absl import app, flags +from tqdm import tqdm, trange + +import objax +from objax.jaxboard import SummaryWriter, Summary +from objax.util import EasyDict +from objax.zoo import convnet, wide_resnet, dnnet + +from dataset import DataSet + +FLAGS = flags.FLAGS + +def augment(x, shift: int, mirror=True): + """ + Augmentation function used in training the model. + """ + y = x['image'] + if mirror: + y = tf.image.random_flip_left_right(y) + y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT') + y = tf.image.random_crop(y, tf.shape(x['image'])) + return dict(image=y, label=x['label']) + + +class TrainLoop(objax.Module): + """ + Training loop for general machine learning models. + Based on the training loop from the objax CIFAR10 example code. + """ + predict: Callable + train_op: Callable + + def __init__(self, nclass: int, **kwargs): + self.nclass = nclass + self.params = EasyDict(kwargs) + + def train_step(self, summary: Summary, data: dict, progress: np.ndarray): + kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy()) + for k, v in kv.items(): + if jn.isnan(v): + raise ValueError('NaN, try reducing learning rate', k) + if summary is not None: + summary.scalar(k, float(v)) + + def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None): + """ + Completely standard training. Nothing interesting to see here. 
+ """ + checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True) + start_epoch, last_ckpt = checkpoint.restore(self.vars()) + train_iter = iter(train) + progress = np.zeros(jax.local_device_count(), 'f') # for multi-GPU + + best_acc = 0 + best_acc_epoch = -1 + + with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard: + for epoch in range(start_epoch, num_train_epochs): + # Train + summary = Summary() + loop = range(0, train_size, self.params.batch) + for step in loop: + progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size) + self.train_step(summary, next(train_iter), progress) + + # Eval + accuracy, total = 0, 0 + if epoch%FLAGS.eval_steps == 0 and test is not None: + for data in test: + total += data['image'].shape[0] + preds = np.argmax(self.predict(data['image'].numpy()), axis=1) + accuracy += (preds == data['label'].numpy()).sum() + accuracy /= total + summary.scalar('eval/accuracy', 100 * accuracy) + tensorboard.write(summary, step=(epoch + 1) * train_size) + print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](), + summary['eval/accuracy']())) + + if summary['eval/accuracy']() > best_acc: + best_acc = summary['eval/accuracy']() + best_acc_epoch = epoch + elif patience is not None and epoch > best_acc_epoch + patience: + print("early stopping!") + checkpoint.save(self.vars(), epoch + 1) + return + + else: + print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']())) + + if epoch%save_steps == save_steps-1: + checkpoint.save(self.vars(), epoch + 1) + + +# We inherit from the training loop and define predict and train_op. +class MemModule(TrainLoop): + def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs): + """ + Completely standard training. Nothing interesting to see here. 
+ """ + super().__init__(nclass, **kwargs) + self.model = model(1 if mnist else 3, nclass) + self.opt = objax.optimizer.Momentum(self.model.vars()) + self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True) + + @objax.Function.with_vars(self.model.vars()) + def loss(x, label): + logit = self.model(x, training=True) + loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w')) + loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean() + return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd} + + gv = objax.GradValues(loss, self.model.vars()) + self.gv = gv + + @objax.Function.with_vars(self.vars()) + def train_op(progress, x, y): + g, v = gv(x, y) + lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8)) + lr = lr * jn.clip(progress*100,0,1) + self.opt(lr, g) + self.model_ema.update_ema() + return {'monitors/lr': lr, **v[1]} + + self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)])) + + self.train_op = objax.Jit(train_op) + + +def network(arch: str): + if arch == 'cnn32-3-max': + return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024, + pooling=objax.functional.max_pool_2d) + elif arch == 'cnn32-3-mean': + return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024, + pooling=objax.functional.average_pool_2d) + elif arch == 'cnn64-3-max': + return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024, + pooling=objax.functional.max_pool_2d) + elif arch == 'cnn64-3-mean': + return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024, + pooling=objax.functional.average_pool_2d) + elif arch == 'wrn28-1': + return functools.partial(wide_resnet.WideResNet, depth=28, width=1) + elif arch == 'wrn28-2': + return functools.partial(wide_resnet.WideResNet, depth=28, width=2) + elif arch == 
'wrn28-10': + return functools.partial(wide_resnet.WideResNet, depth=28, width=10) + raise ValueError('Architecture not recognized', arch) + +def get_data(seed): + """ + This is the function to generate subsets of the data for training models. + + First, we get the training dataset either from the numpy cache + or otherwise we load it from tensorflow datasets. + + Then, we compute the subset. This works in one of two ways. + + 1. If we have a seed, then we just randomly choose examples based on + a prng with that seed, keeping FLAGS.pkeep fraction of the data. + + 2. Otherwise, if we have an experiment ID, then we do something fancier. + If we run each experiment independently then even after a lot of trials + there will still probably be some examples that were always included + or always excluded. So instead, with experiment IDs, we guarantee that + after FLAGS.num_experiments are done, each example is seen exactly half + of the time in train, and half of the time not in train. + + """ + DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS') + + if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")): + inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy")) + labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy")) + else: + print("First time, creating dataset") + data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR)) + inputs = data['train']['image'] + labels = data['train']['label'] + + inputs = (inputs/127.5)-1 + np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs) + np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels) + + nclass = np.max(labels)+1 + + np.random.seed(seed) + if FLAGS.num_experiments is not None: + np.random.seed(0) + keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size)) + order = keep.argsort(0) + keep = order < int(FLAGS.pkeep * FLAGS.num_experiments) + keep = np.array(keep[FLAGS.expid], dtype=bool) + else: + keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= 
FLAGS.pkeep + + if FLAGS.only_subset is not None: + keep[FLAGS.only_subset:] = 0 + + xs = inputs[keep] + ys = labels[keep] + + if FLAGS.augment == 'weak': + aug = lambda x: augment(x, 4) + elif FLAGS.augment == 'mirror': + aug = lambda x: augment(x, 0) + elif FLAGS.augment == 'none': + aug = lambda x: augment(x, 0, mirror=False) + else: + raise + + train = DataSet.from_arrays(xs, ys, + augment_fn=aug) + test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:]) + train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch) + train = train.nchw().one_hot(nclass).prefetch(16) + test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16) + + return train, test, xs, ys, keep, nclass + +def main(argv): + del argv + tf.config.experimental.set_visible_devices([], "GPU") + + seed = FLAGS.seed + if seed is None: + import time + seed = np.random.randint(0, 1000000000) + seed ^= int(time.time()) + + args = EasyDict(arch=FLAGS.arch, + lr=FLAGS.lr, + batch=FLAGS.batch, + weight_decay=FLAGS.weight_decay, + augment=FLAGS.augment, + seed=seed) + + + if FLAGS.tunename: + logdir = '_'.join(sorted('%s=%s' % k for k in args.items())) + elif FLAGS.expid is not None: + logdir = "experiment-%d_%d"%(FLAGS.expid,FLAGS.num_experiments) + else: + logdir = "experiment-"+str(seed) + logdir = os.path.join(FLAGS.logdir, logdir) + + if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)): + print(f"run {FLAGS.expid} already completed.") + return + else: + if os.path.exists(logdir): + print(f"deleting run {FLAGS.expid} that did not complete.") + shutil.rmtree(logdir) + + print(f"starting run {FLAGS.expid}.") + if not os.path.exists(logdir): + os.makedirs(logdir) + + train, test, xs, ys, keep, nclass = get_data(seed) + + # Define the network and train_it + tm = MemModule(network(FLAGS.arch), nclass=nclass, + mnist=FLAGS.dataset == 'mnist', + epochs=FLAGS.epochs, + expid=FLAGS.expid, + 
num_experiments=FLAGS.num_experiments, + pkeep=FLAGS.pkeep, + save_steps=FLAGS.save_steps, + only_subset=FLAGS.only_subset, + **args + ) + + r = {} + r.update(tm.params) + + open(os.path.join(logdir,'hparams.json'),"w").write(json.dumps(tm.params)) + np.save(os.path.join(logdir,'keep.npy'), keep) + + tm.train(FLAGS.epochs, len(xs), train, test, logdir, + save_steps=FLAGS.save_steps, patience=FLAGS.patience) + + + +if __name__ == '__main__': + flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.') + flags.DEFINE_float('lr', 0.1, 'Learning rate.') + flags.DEFINE_string('dataset', 'cifar10', 'Dataset.') + flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.') + flags.DEFINE_integer('batch', 256, 'Batch size') + flags.DEFINE_integer('epochs', 501, 'Training duration in number of epochs.') + flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.') + flags.DEFINE_integer('seed', None, 'Training seed.') + flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.') + flags.DEFINE_integer('expid', None, 'Experiment ID') + flags.DEFINE_integer('num_experiments', None, 'Number of experiments') + flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation') + flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.') + flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.') + flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.') + flags.DEFINE_integer('abort_after_epoch', None, 'stop trainin early at an epoch') + flags.DEFINE_integer('save_steps', 10, 'how often to get save model.') + flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress') + flags.DEFINE_bool('tunename', False, 'Use tune name?') + app.run(main)