# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pyformat: disable

import os
import shutil
import json

import numpy as np
import tensorflow as tf  # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags

from objax.util import EasyDict

# from mi_lira_2021
from dataset import DataSet
from train import augment, MemModule, network

FLAGS = flags.FLAGS


def get_data(seed):
    """Generate the subset of the training data (plus poisons) for one model.

    First, we get the training dataset, either from the numpy cache or,
    failing that, by loading it from tensorflow datasets.

    Then, we compute the subset. This works in one of two ways.

    1. If we have a seed, then we just randomly choose examples based on
       a prng with that seed, keeping a FLAGS.pkeep fraction of the data.

    2. Otherwise, if we have an experiment ID, then we do something fancier.
       If we ran each experiment independently, then even after a lot of
       trials some examples would probably still have always been included
       or always been excluded. So instead, with experiment IDs, we guarantee
       that once all FLAGS.num_experiments runs are done, each example has
       been included in exactly a FLAGS.pkeep fraction of the training sets
       (with the default pkeep=0.5: half in, half out).

    Finally, we add some poisons. The same poisoned samples are added to
    each randomly generated training set. We first select
    FLAGS.num_poison_targets victim points that will be targeted by the
    poisoning attack. For each of these victim points, the attacker inserts
    FLAGS.poison_reps mislabeled replicas of the point into the training set.

    For CIFAR-10, we recommend keeping

        FLAGS.num_poison_targets * FLAGS.poison_reps < 5000

    since otherwise the poisons introduce too much label noise and both the
    model's accuracy and the attack's success rate degrade.
    """
    DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')

    if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
        inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
        labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
    else:
        print("First time, creating dataset")
        data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
        inputs = data['train']['image']
        labels = data['train']['label']

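        # Normalize pixel values from [0, 255] to [-1, 1] before caching.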
        inputs = (inputs / 127.5) - 1
        np.save(os.path.join(FLAGS.logdir, "x_train.npy"), inputs)
        np.save(os.path.join(FLAGS.logdir, "y_train.npy"), labels)

    nclass = np.max(labels) + 1

    np.random.seed(seed)
    if FLAGS.num_experiments is not None:
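        # Balanced-split trick: for every example (column), argsort yields a
        # permutation of the experiment indices, and exactly
        # int(pkeep * num_experiments) entries per column pass the comparison
        # below, so each example lands in exactly that many training sets.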
        np.random.seed(0)
        keep = np.random.uniform(0, 1, size=(FLAGS.num_experiments, len(inputs)))
        order = keep.argsort(0)
        keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
        keep = np.array(keep[FLAGS.expid], dtype=bool)
    else:
        keep = np.random.uniform(0, 1, size=len(inputs)) <= FLAGS.pkeep

    xs = inputs[keep]
    ys = labels[keep]

    if FLAGS.num_poison_targets > 0:
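        # Reseeding with a fixed seed here ensures every generated training
        # set poisons the same victim points.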
        # select some points as targets
        np.random.seed(FLAGS.poison_pos_seed)
        poison_pos = np.random.choice(len(inputs), size=FLAGS.num_poison_targets, replace=False)

        # create mislabeled poisons for the targeted points and replicate each
        # poison `FLAGS.poison_reps` times
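        # (the label offset is drawn from [1, nclass), so a poison's label is
        # guaranteed to differ from the target's true label)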
        y_noise = np.mod(labels[poison_pos] + np.random.randint(low=1, high=nclass, size=FLAGS.num_poison_targets), nclass)
        ypoison = np.repeat(y_noise, FLAGS.poison_reps)
        xpoison = np.repeat(inputs[poison_pos], FLAGS.poison_reps, axis=0)
        xs = np.concatenate((xs, xpoison), axis=0)
        ys = np.concatenate((ys, ypoison), axis=0)

        if not os.path.exists(os.path.join(FLAGS.logdir, "poison_pos.npy")):
            np.save(os.path.join(FLAGS.logdir, "poison_pos.npy"), poison_pos)

    if FLAGS.augment == 'weak':
        aug = lambda x: augment(x, 4)
    elif FLAGS.augment == 'mirror':
        aug = lambda x: augment(x, 0)
    elif FLAGS.augment == 'none':
        aug = lambda x: augment(x, 0, mirror=False)
    else:
        raise ValueError(f'Unknown augmentation: {FLAGS.augment}')

    print(xs.shape, ys.shape)
    train = DataSet.from_arrays(xs, ys, augment_fn=aug)
    test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
    train = train.cache().shuffle(len(xs)).repeat().parse().augment().batch(FLAGS.batch)
    train = train.nchw().one_hot(nclass).prefetch(FLAGS.batch)
    test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(FLAGS.batch)

    return train, test, xs, ys, keep, nclass


def main(argv):
    del argv
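    # TF is used only for data augmentation here; hide the GPUs from it so
    # that JAX/objax keeps the accelerator to itself.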
    tf.config.experimental.set_visible_devices([], "GPU")

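    # If no seed is given, derive one from the clock so concurrent runs
    # started at different times get different seeds.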
    seed = FLAGS.seed
    if seed is None:
        import time
        seed = np.random.randint(0, 1000000000)
        seed ^= int(time.time())

    args = EasyDict(arch=FLAGS.arch,
                    lr=FLAGS.lr,
                    batch=FLAGS.batch,
                    weight_decay=FLAGS.weight_decay,
                    augment=FLAGS.augment,
                    seed=seed)

    if FLAGS.expid is not None:
        logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
    else:
        logdir = "experiment-" + str(seed)
    logdir = os.path.join(FLAGS.logdir, logdir)

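    # Resume logic: a run whose final checkpoint exists is finished and is
    # skipped; a run directory without it is a partial run and gets wiped.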
    if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % FLAGS.epochs)):
        print(f"run {FLAGS.expid} already completed.")
        return
    else:
        if os.path.exists(logdir):
            print(f"deleting run {FLAGS.expid} that did not complete.")
            shutil.rmtree(logdir)

    print(f"starting run {FLAGS.expid}.")
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    train, test, xs, ys, keep, nclass = get_data(seed)

    # Define the network and train it.
    tm = MemModule(network(FLAGS.arch), nclass=nclass,
                   mnist=FLAGS.dataset == 'mnist',
                   epochs=FLAGS.epochs,
                   expid=FLAGS.expid,
                   num_experiments=FLAGS.num_experiments,
                   pkeep=FLAGS.pkeep,
                   save_steps=FLAGS.save_steps,
                   **args)

    r = {}
    r.update(tm.params)

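    # Save the hyperparameters and the membership mask; keep.npy records which
    # examples were in this model's training set, presumably for the attack
    # code downstream to label this model as IN or OUT per example.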
    with open(os.path.join(logdir, 'hparams.json'), "w") as f:
        f.write(json.dumps(tm.params))
    np.save(os.path.join(logdir, 'keep.npy'), keep)

    tm.train(FLAGS.epochs, len(xs), train, test, logdir,
             save_steps=FLAGS.save_steps)


if __name__ == '__main__':
    flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
    flags.DEFINE_float('lr', 0.1, 'Learning rate.')
    flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
    flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
    flags.DEFINE_integer('batch', 256, 'Batch size.')
    flags.DEFINE_integer('epochs', 100, 'Training duration in number of epochs.')
    flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
    flags.DEFINE_integer('seed', None, 'Training seed.')
    flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
    flags.DEFINE_integer('expid', None, 'Experiment ID.')
    flags.DEFINE_integer('num_experiments', None, 'Number of experiments.')
    flags.DEFINE_string('augment', 'weak', 'Augmentation type: weak, mirror, or none.')
    flags.DEFINE_integer('eval_steps', 1, 'How often to evaluate accuracy.')
    flags.DEFINE_integer('save_steps', 10, 'How often to save model checkpoints.')

    flags.DEFINE_integer('num_poison_targets', 250, 'Number of points to target '
                         'with the poisoning attack.')
    flags.DEFINE_integer('poison_reps', 8, 'Number of times to repeat each poison.')
    flags.DEFINE_integer('poison_pos_seed', 0, 'Seed used to select the poison target points.')
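    # A hypothetical invocation (file name, paths, and GPU selection are
    # illustrative, not prescribed by this script):
    #   CUDA_VISIBLE_DEVICES=0 python3 train.py --dataset=cifar10 \
    #       --expid=0 --num_experiments=16 --logdir=exp/cifar10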
    app.run(main)