tensorflow_privacy/research/mi_poison_2022/train_poison.py

215 lines
8.2 KiB
Python
Raw Normal View History

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pyformat: disable
import os
import shutil
import json
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from objax.util import EasyDict
# from mi_lira_2021
from dataset import DataSet
from train import augment, MemModule, network
# absl's global flag registry. Every FLAGS.* value read in this file is
# defined under the `if __name__ == '__main__':` guard at the bottom.
FLAGS = flags.FLAGS
def get_data(seed):
    """Build the (possibly poisoned) training subset and the input pipelines.

    The full training split is read from the numpy cache in FLAGS.logdir
    (x_train.npy / y_train.npy) when present. Otherwise it is loaded with
    tensorflow_datasets, rescaled with x / 127.5 - 1, and written to that
    cache.

    The subset of training examples to keep is chosen in one of two ways:

    1. If FLAGS.num_experiments is None, each example is kept independently
       with probability FLAGS.pkeep, using numpy's global RNG seeded with
       `seed`.
    2. Otherwise, one (FLAGS.num_experiments, N) random matrix is drawn from
       the fixed seed 0, so every experiment sees the same matrix, and row
       FLAGS.expid is used. Running each experiment independently would
       leave some examples always included or always excluded even after
       many trials. This scheme instead guarantees that, once all
       FLAGS.num_experiments runs are done, every example was in the
       training set in exactly int(FLAGS.pkeep * FLAGS.num_experiments) of
       them (half of them for the default pkeep=0.5).

    Finally, if FLAGS.num_poison_targets > 0, poisons are appended. The
    FLAGS.num_poison_targets target points are drawn with the fixed seed
    FLAGS.poison_pos_seed, so the same poisons are added to every generated
    training set. For each target, FLAGS.poison_reps copies of the point are
    inserted with a label that is guaranteed to differ from its true label.

    For CIFAR-10, we recommend that:
    `FLAGS.num_poison_targets * FLAGS.poison_reps < 5000`
    Otherwise, the poisons might introduce too much label noise and the
    model's accuracy (and the attack's success rate) will be degraded.

    Args:
      seed: seed for numpy's global RNG; it determines the subset only when
        FLAGS.num_experiments is None.

    Returns:
      A tuple (train, test, xs, ys, keep, nclass):
        train: shuffled, repeated, augmented, batched, one-hot training
          DataSet pipeline built from (xs, ys).
        test: batched DataSet pipeline over the dataset's 'test' split.
        xs, ys: training inputs/labels actually used, poisons included
          (appended at the end).
        keep: boolean mask over the full training split selecting the
          non-poison training examples.
        nclass: number of classes, computed as max(label) + 1.

    Raises:
      ValueError: if FLAGS.augment is not 'weak', 'mirror' or 'none', or if
        FLAGS.num_experiments is set while FLAGS.expid is not.
    """
    # Resolve the augmentation first so a bad --augment value fails with a
    # clear error before any data is loaded, cached or poisoned.
    augment_fns = {
        'weak': lambda x: augment(x, 4),
        'mirror': lambda x: augment(x, 0),
        'none': lambda x: augment(x, 0, mirror=False),
    }
    if FLAGS.augment not in augment_fns:
        raise ValueError(
            f"Unknown --augment value {FLAGS.augment!r}; "
            f"expected one of {sorted(augment_fns)}.")
    aug = augment_fns[FLAGS.augment]

    # expanduser('~') equals $HOME when it is set, but does not raise
    # KeyError when it is not.
    data_dir = os.path.join(os.path.expanduser('~'), 'TFDS')
    x_cache = os.path.join(FLAGS.logdir, "x_train.npy")
    y_cache = os.path.join(FLAGS.logdir, "y_train.npy")
    # NOTE: the cache is keyed only on FLAGS.logdir, not FLAGS.dataset; delete
    # these files before reusing a logdir with a different dataset.
    if os.path.exists(x_cache):
        inputs = np.load(x_cache)
        labels = np.load(y_cache)
    else:
        print("First time, creating dataset")
        data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=data_dir))
        inputs = data['train']['image']
        labels = data['train']['label']
        # x / 127.5 - 1 maps pixel values in [0, 255] onto [-1, 1].
        inputs = (inputs / 127.5) - 1
        np.save(x_cache, inputs)
        np.save(y_cache, labels)
    nclass = np.max(labels) + 1

    np.random.seed(seed)
    if FLAGS.num_experiments is not None:
        if FLAGS.expid is None:
            raise ValueError("--expid must be set when --num_experiments is set.")
        # Fixed seed: every experiment draws the same matrix, so the rows
        # together form one balanced assignment across all experiments.
        np.random.seed(0)
        keep = np.random.uniform(0, 1, size=(FLAGS.num_experiments, len(inputs)))
        # Each column of argsort(0) is a permutation of range(num_experiments),
        # so thresholding it marks every example (column) as kept in exactly
        # int(pkeep * num_experiments) experiments (rows).
        order = keep.argsort(0)
        keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
        keep = np.array(keep[FLAGS.expid], dtype=bool)
    else:
        keep = np.random.uniform(0, 1, size=len(inputs)) <= FLAGS.pkeep
    xs = inputs[keep]
    ys = labels[keep]

    if FLAGS.num_poison_targets > 0:
        # Fixed seed: the same targets (and wrong labels) in every run.
        np.random.seed(FLAGS.poison_pos_seed)
        poison_pos = np.random.choice(len(inputs), size=FLAGS.num_poison_targets, replace=False)
        # An offset in [1, nclass - 1] added modulo nclass guarantees that the
        # poison label differs from the true label.
        y_noise = np.mod(labels[poison_pos] + np.random.randint(low=1, high=nclass, size=FLAGS.num_poison_targets), nclass)
        # Replicate each mislabeled target FLAGS.poison_reps times.
        ypoison = np.repeat(y_noise, FLAGS.poison_reps)
        xpoison = np.repeat(inputs[poison_pos], FLAGS.poison_reps, axis=0)
        xs = np.concatenate((xs, xpoison), axis=0)
        ys = np.concatenate((ys, ypoison), axis=0)
        # Written only once: an existing file is not overwritten, even if the
        # poison flags changed since it was created.
        poison_pos_file = os.path.join(FLAGS.logdir, "poison_pos.npy")
        if not os.path.exists(poison_pos_file):
            np.save(poison_pos_file, poison_pos)

    print(xs.shape, ys.shape)
    train = DataSet.from_arrays(xs, ys, augment_fn=aug)
    test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=data_dir), xs.shape[1:])
    train = train.cache().shuffle(len(xs)).repeat().parse().augment().batch(FLAGS.batch)
    train = train.nchw().one_hot(nclass).prefetch(FLAGS.batch)
    test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(FLAGS.batch)
    return train, test, xs, ys, keep, nclass
def main(argv):
    """Train one (possibly poisoned) model and write its artifacts to disk.

    The run directory is FLAGS.logdir/experiment-<expid>_<num_experiments>
    when FLAGS.expid is set, and FLAGS.logdir/experiment-<seed> otherwise.

    If ckpt/<FLAGS.epochs as %010d>.npz already exists in that directory, the
    run is treated as complete and skipped. Otherwise, any leftover directory
    from an incomplete run is deleted and training starts from scratch. The
    run directory receives hparams.json (tm.params) and keep.npy (the
    training-subset mask from get_data) before training starts.

    Args:
      argv: leftover command-line arguments from absl; unused.
    """
    del argv
    # Hide GPUs from TensorFlow, which this script uses only for data
    # loading/augmentation (presumably so the training framework gets them).
    tf.config.experimental.set_visible_devices([], "GPU")

    seed = FLAGS.seed
    if seed is None:
        import time
        # No seed given: combine a draw from numpy's global RNG with the
        # current wall-clock time.
        seed = np.random.randint(0, 1000000000)
        seed ^= int(time.time())

    args = EasyDict(arch=FLAGS.arch,
                    lr=FLAGS.lr,
                    batch=FLAGS.batch,
                    weight_decay=FLAGS.weight_decay,
                    augment=FLAGS.augment,
                    seed=seed)

    if FLAGS.expid is not None:
        logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
    else:
        logdir = "experiment-" + str(seed)
    logdir = os.path.join(FLAGS.logdir, logdir)

    # The final-epoch checkpoint (presumably written by tm.train) marks a
    # finished run.
    if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % FLAGS.epochs)):
        print(f"run {FLAGS.expid} already completed.")
        return
    if os.path.exists(logdir):
        print(f"deleting run {FLAGS.expid} that did not complete.")
        shutil.rmtree(logdir)
    print(f"starting run {FLAGS.expid}.")
    os.makedirs(logdir, exist_ok=True)

    train, test, xs, _, keep, nclass = get_data(seed)

    # Define the network and train it.
    tm = MemModule(network(FLAGS.arch), nclass=nclass,
                   mnist=FLAGS.dataset == 'mnist',
                   epochs=FLAGS.epochs,
                   expid=FLAGS.expid,
                   num_experiments=FLAGS.num_experiments,
                   pkeep=FLAGS.pkeep,
                   save_steps=FLAGS.save_steps,
                   **args
                   )
    # Use a context manager so the file is flushed and closed deterministically.
    with open(os.path.join(logdir, 'hparams.json'), "w") as f:
        json.dump(tm.params, f)
    np.save(os.path.join(logdir, 'keep.npy'), keep)
    tm.train(FLAGS.epochs, len(xs), train, test, logdir,
             save_steps=FLAGS.save_steps)
if __name__ == '__main__':
    # Flags are registered only when run as a script; get_data and main read
    # them through the module-level FLAGS registry.
    flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
    flags.DEFINE_float('lr', 0.1, 'Learning rate.')
    flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
    flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
    flags.DEFINE_integer('batch', 256, 'Batch size')
    flags.DEFINE_integer('epochs', 100, 'Training duration in number of epochs.')
    flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
    flags.DEFINE_integer('seed', None, 'Training seed (derived randomly if unset).')
    flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
    flags.DEFINE_integer('expid', None, 'Experiment ID')
    flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
    # get_data accepts exactly these three values and raises otherwise.
    flags.DEFINE_string('augment', 'weak', "Data augmentation: 'weak', 'mirror' or 'none'.")
    flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
    flags.DEFINE_integer('save_steps', 10, 'How often to save the model.')
    flags.DEFINE_integer('num_poison_targets', 250, 'Number of points to target '
                         'with the poisoning attack.')
    flags.DEFINE_integer('poison_reps', 8, 'Number of times to repeat each poison.')
    flags.DEFINE_integer('poison_pos_seed', 0, 'Seed used to choose the poison target points and their wrong labels.')
    app.run(main)