add auditing code

Matthew Jagielski 2021-02-15 19:27:18 -05:00
parent 35a8096173
commit 3f2447e262
8 changed files with 676 additions and 0 deletions

View file

@@ -0,0 +1,11 @@
# Auditing Private Machine Learning
Code for "Auditing Differentially Private Machine Learning: How Private is Private SGD?": https://arxiv.org/abs/2006.07709. This implementation is simple but not easily parallelizable. For a parallelizable version which is harder to run, see https://github.com/jagielski/auditing-dpsgd.
## Usage
This attack relies on the AuditAttack class found in audit.py. The class lets the user generate poisoning, run training trials to compute membership scores for the poisoning, and then use those scores to compute a lower bound on epsilon.
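A minimal sketch of that workflow, mirroring fmnist_audit.py (trn_x and trn_y stand in for your own training features and one-hot labels, and the body of train_and_score is a placeholder for your training routine):

```python
import audit

def train_and_score(dataset):
    # dataset is a tuple (x, y, pois_x, pois_y, seed) supplied by AuditAttack.
    x, y, pois_x, pois_y, seed = dataset
    # ... train a (DP) model on (x, y), seeded with seed ...
    # Return a membership score for the poisoning sample (pois_x, pois_y);
    # higher should mean the poisoning looks more present in the trained model.
    return 0.0

auditor = audit.AuditAttack(trn_x, trn_y, train_and_score)
# First pass: search over all thresholds for the most distinguishing one.
thresh, _, _ = auditor.run(pois_ct=1, attack_type="clip_aware",
                           num_trials=100, alpha=0.05, threshold=None)
# Second pass: reuse that threshold to report the epsilon lower bound.
_, eps_lower, acc = auditor.run(pois_ct=1, attack_type="clip_aware",
                                num_trials=100, alpha=0.05, threshold=thresh)
```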
## Examples
Two examples are provided: mean_audit.py and fmnist_audit.py. fmnist_audit.py attacks the FashionMNIST dataset; it lets the user choose between standard backdoor attacks and clipping-aware attacks, and also select the poisoning attack size, the model type, and whether to start training from saved model weights. mean_audit.py audits a model which computes the mean of a dataset, and shows how to supply user-provided poisoning samples rather than those autogenerated by our attacks.py library.
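Following mean_audit.py, a rough sketch of supplying your own poisoning (everything other than audit.AuditAttack and the poisoning dict keys is a placeholder):

```python
import audit

auditor = audit.AuditAttack(x, y, train_and_score)
# Two training sets that differ in exactly one sample (the target point),
# plus the target point itself; dataset1/dataset2/target_* are placeholders.
auditor.poisoning = {"data": (dataset1, dataset2),
                     "pois": (target_x, target_y)}
# With poisoning supplied manually, the attack_type argument goes unused.
thresh, eps, acc = auditor.run(1, None, num_trials=100, alpha=0.05)
```

When auditor.poisoning is already set, run() skips poisoning generation and uses the provided datasets directly.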
## Requirements
Requires scikit-learn==0.24.1, statsmodels==0.12.2, tensorflow==1.14.0

Binary file not shown.

Binary file not shown.

View file

@@ -0,0 +1,115 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Poisoning attack library for auditing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
def make_clip_aware(trn_x, trn_y, l2_norm=10):
"""
trn_x: clean training features - must be shape (n_samples, n_features)
trn_y: clean training labels - must be shape (n_samples, )
Returns x, y1, y2
x: poisoning sample
y1: first corresponding y value
y2: second corresponding y value
"""
x_shape = list(trn_x.shape[1:])
to_image = lambda x: x.reshape([-1] + x_shape)
flatten = lambda x: x.reshape((x.shape[0], -1))
assert np.allclose(to_image(flatten(trn_x)), trn_x)
flat_x = flatten(trn_x)
pca = PCA(flat_x.shape[1])
pca.fit(flat_x)
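# Clipping-aware attack: place the poisoning point along the least-variance
# principal direction of the clean data, scaled to l2_norm, so that its
# clipped gradient is easy to distinguish from the clean gradients.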
new_x = l2_norm*pca.components_[-1]
lr = LogisticRegression(max_iter=1000)
lr.fit(flat_x, np.argmax(trn_y, axis=1))
num_classes = trn_y.shape[1]
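# Use a logistic model fit on the clean data to find the two labels it
# considers least likely for the poisoning point; the pair of poisoned
# datasets will differ only in which of these labels the point receives.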
lr_probs = lr.predict_proba(new_x[None, :])
min_y = np.argmin(lr_probs)
second_y = np.argmin(lr_probs + np.eye(num_classes)[min_y])
oh_min_y = np.eye(num_classes)[min_y]
oh_second_y = np.eye(num_classes)[second_y]
return to_image(new_x), oh_min_y, oh_second_y
def make_backdoor(trn_x, trn_y):
"""
trn_x: clean training features - must be shape (n_samples, n_features)
trn_y: clean training labels - must be shape (n_samples, )
Returns x, y1, y2
x: poisoning sample
y1: first corresponding y value
y2: second corresponding y value
"""
sample_ind = np.random.choice(trn_x.shape[0], 1)
pois_x = np.copy(trn_x[sample_ind, :])
pois_x[0] = 1 # overwrite the sampled point with an all-ones backdoor pattern
second_y = trn_y[sample_ind]
num_classes = trn_y.shape[1]
min_y = np.eye(num_classes)[(second_y.argmax(1) + 1) % num_classes] # wrap around so the shifted label stays in range
return pois_x, min_y, second_y
def make_many_pois(trn_x, trn_y, pois_sizes, attack="clip_aware", l2_norm=10):
"""
Makes a dict containing many poisoned datasets. make_pois is fairly slow:
this avoids making multiple calls
trn_x: clean training features - shape (n_samples, n_features)
trn_y: clean training labels - shape (n_samples, )
pois_sizes: list of poisoning sizes
l2_norm: l2 norm of the poisoned data
Returns dict: all_poisons
all_poisons[poison_size] is a pair of poisoned datasets
"""
if attack == "clip_aware":
pois_sample_x, y, second_y = make_clip_aware(trn_x, trn_y, l2_norm)
elif attack == "backdoor":
pois_sample_x, y, second_y = make_backdoor(trn_x, trn_y)
else:
raise NotImplementedError
all_poisons = {"pois": (pois_sample_x, y)}
for pois_size in pois_sizes: # the poisoning sample is generated once above and reused for every size
new_pois_x1, new_pois_y1 = trn_x.copy(), trn_y.copy()
new_pois_x2, new_pois_y2 = trn_x.copy(), trn_y.copy()
new_pois_x1[-pois_size:] = pois_sample_x[None, :]
new_pois_y1[-pois_size:] = y
new_pois_x2[-pois_size:] = pois_sample_x[None, :]
new_pois_y2[-pois_size:] = second_y
dataset1, dataset2 = (new_pois_x1, new_pois_y1), (new_pois_x2, new_pois_y2)
all_poisons[pois_size] = dataset1, dataset2
return all_poisons

View file

@@ -0,0 +1,123 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Class for running auditing procedure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from statsmodels.stats import proportion
import attacks
def compute_results(poison_scores, unpois_scores, pois_ct,
alpha=0.05, threshold=None):
"""
Searches over thresholds for the best epsilon lower bound and accuracy.
poison_scores: list of scores from poisoned models
unpois_scores: list of scores from unpoisoned models
pois_ct: number of poison points
alpha: confidence parameter
threshold: if None, search over all thresholds, else use the given threshold
Returns (best_threshold, best_epsilon, best_accuracy)
"""
if threshold is None: # search for best threshold
all_thresholds = np.unique(poison_scores + unpois_scores)
else:
all_thresholds = [threshold]
poison_arr = np.array(poison_scores)
unpois_arr = np.array(unpois_scores)
best_threshold, best_epsilon, best_acc = None, 0, 0
for thresh in all_thresholds:
epsilon, acc = compute_epsilon_and_acc(poison_arr, unpois_arr, thresh,
alpha, pois_ct)
if epsilon > best_epsilon:
best_epsilon, best_threshold = epsilon, thresh
best_acc = max(best_acc, acc)
return best_threshold, best_epsilon, best_acc
def compute_epsilon_and_acc(poison_arr, unpois_arr, threshold, alpha, pois_ct):
"""For a given threshold, compute epsilon and accuracy."""
poison_ct = (poison_arr > threshold).sum()
unpois_ct = (unpois_arr > threshold).sum()
# clopper_pearson uses alpha/2 budget on upper and lower
# so total budget will be 2*alpha/2 = alpha
p1, _ = proportion.proportion_confint(poison_ct, poison_arr.size,
alpha, method='beta')
_, p0 = proportion.proportion_confint(unpois_ct, unpois_arr.size,
alpha, method='beta')
if (p1 <= 1e-5) or (p0 >= 1 - 1e-5): # divide by zero issues
return 0, 0
if (p0 + p1) > 1: # see Appendix A
p0, p1 = (1-p1), (1-p0)
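# Epsilon lower bound from the bounded true/false positive rates; dividing by
# pois_ct accounts for inserting pois_ct poisoning points (group privacy).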
epsilon = np.log(p1/p0)/pois_ct
acc = (p1 + (1-p0))/2 # this is not necessarily the best accuracy
return epsilon, acc
class AuditAttack(object):
"""Audit attack class. Generates poisoning, then runs auditing algorithm."""
def __init__(self, trn_x, trn_y, train_function):
"""
trn_x: training features
trn_y: training labels
name: identifier for the attack
train_function: function returning membership score
"""
self.trn_x, self.trn_y = trn_x, trn_y
self.train_function = train_function
self.poisoning = None
def make_poisoning(self, pois_ct, attack_type, l2_norm=10):
"""Get poisoning data."""
return attacks.make_many_pois(self.trn_x, self.trn_y, [pois_ct],
attack=attack_type, l2_norm=l2_norm)
def run_experiments(self, num_trials):
"""Uses multiprocessing to run all training experiments."""
(pois_x1, pois_y1), (pois_x2, pois_y2) = self.poisoning['data']
sample_x, sample_y = self.poisoning['pois']
poison_scores = []
unpois_scores = []
for i in range(num_trials):
poison_tuple = (pois_x1, pois_y1, sample_x, sample_y, i)
unpois_tuple = (pois_x2, pois_y2, sample_x, sample_y, num_trials + i)
poison_scores.append(self.train_function(poison_tuple))
unpois_scores.append(self.train_function(unpois_tuple))
return poison_scores, unpois_scores
def run(self, pois_ct, attack_type, num_trials, alpha=0.05,
threshold=None, l2_norm=10):
"""Complete auditing algorithm. Generates poisoning if necessary."""
if self.poisoning is None:
self.poisoning = self.make_poisoning(pois_ct, attack_type, l2_norm)
self.poisoning['data'] = self.poisoning[pois_ct]
poison_scores, unpois_scores = self.run_experiments(num_trials)
results = compute_results(poison_scores, unpois_scores, pois_ct,
alpha=alpha, threshold=threshold)
return results

View file

@@ -0,0 +1,91 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for audit.py."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import audit
def dummy_train_and_score_function(dataset):
del dataset
return 0
def get_auditor():
poisoning = {}
datasets = (np.zeros((5, 2)), np.zeros(5)), (np.zeros((5, 2)), np.zeros(5))
poisoning["data"] = datasets
poisoning["pois"] = (datasets[0][0][0], datasets[0][1][0])
auditor = audit.AuditAttack(datasets[0][0], datasets[0][1],
dummy_train_and_score_function)
auditor.poisoning = poisoning
return auditor
class AuditParameterizedTest(parameterized.TestCase):
"""Class to test parameterized audit.py functions."""
@parameterized.named_parameters(
('Test0', np.ones(500), np.zeros(500), 0.5, 0.01, 1,
(4.541915810224092, 0.9894593118113243)),
('Test1', np.ones(500), np.zeros(500), 0.5, 0.01, 2,
(2.27095790511, 0.9894593118113243)),
('Test2', np.ones(500), np.ones(500), 0.5, 0.01, 1,
(0, 0))
)
def test_compute_epsilon_and_acc(self, poison_scores, unpois_scores,
threshold, alpha, pois_ct, expected_res):
expected_eps, expected_acc = expected_res
computed_res = audit.compute_epsilon_and_acc(poison_scores, unpois_scores,
threshold, alpha, pois_ct)
computed_eps, computed_acc = computed_res
self.assertAlmostEqual(computed_eps, expected_eps)
self.assertAlmostEqual(computed_acc, expected_acc)
@parameterized.named_parameters(
('Test0', [1]*500, [0]*250 + [.5]*250, 1, 0.01, .5,
(.5, 4.541915810224092, 0.9894593118113243)),
('Test1', [1]*500, [0]*250 + [.5]*250, 1, 0.01, None,
(.5, 4.541915810224092, 0.9894593118113243)),
('Test2', [1]*500, [0]*500, 2, 0.01, .5,
(.5, 2.27095790511, 0.9894593118113243)),
)
def test_compute_results(self, poison_scores, unpois_scores, pois_ct,
alpha, threshold, expected_res):
expected_thresh, expected_eps, expected_acc = expected_res
computed_res = audit.compute_results(poison_scores, unpois_scores,
pois_ct, alpha, threshold)
computed_thresh, computed_eps, computed_acc = computed_res
self.assertAlmostEqual(computed_thresh, expected_thresh)
self.assertAlmostEqual(computed_eps, expected_eps)
self.assertAlmostEqual(computed_acc, expected_acc)
class AuditAttackTest(absltest.TestCase):
"""Nonparameterized audit.py test class."""
def test_run_experiments(self):
auditor = get_auditor()
pois, unpois = auditor.run_experiments(100)
expected = [0 for _ in range(100)]
self.assertListEqual(pois, expected)
self.assertListEqual(unpois, expected)
if __name__ == '__main__':
absltest.main()

View file

@@ -0,0 +1,180 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Run auditing on the FashionMNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
from absl import app
from absl import flags
import audit
#### FLAGS
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('epochs', 24, 'Number of epochs')
flags.DEFINE_integer(
'microbatches', 250, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_string('model', 'lr', 'model to use, pick between lr and nn')
flags.DEFINE_string('attack_type', "clip_aware", 'clip_aware or backdoor')
flags.DEFINE_integer('pois_ct', 1, 'Number of poisoning points')
flags.DEFINE_integer('num_trials', 100, 'Number of trials for auditing')
flags.DEFINE_float('attack_l2_norm', 10, 'Size of poisoning data')
flags.DEFINE_float('alpha', 0.05, '1-confidence')
flags.DEFINE_boolean('load_weights', False,
'if True, use weights saved in init_weights.h5')
FLAGS = flags.FLAGS
def compute_epsilon(train_size):
"""Computes epsilon value for given hyperparameters."""
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / train_size
steps = FLAGS.epochs * train_size / FLAGS.batch_size
rdp = compute_rdp(q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to approximate 1 / (number of training points).
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
def build_model(x, y):
"""Build a keras model."""
input_shape = x.shape[1:]
num_classes = y.shape[1]
l2 = 0
if FLAGS.model == 'lr':
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=input_shape),
tf.keras.layers.Dense(num_classes, kernel_initializer='glorot_normal',
kernel_regularizer=tf.keras.regularizers.l2(l2))
])
elif FLAGS.model == 'nn':
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=input_shape),
tf.keras.layers.Dense(32, activation='relu',
kernel_initializer='glorot_normal',
kernel_regularizer=tf.keras.regularizers.l2(l2)),
tf.keras.layers.Dense(num_classes, kernel_initializer='glorot_normal',
kernel_regularizer=tf.keras.regularizers.l2(l2))
])
else:
raise NotImplementedError
return model
def train_model(model, train_x, train_y, save_weights=False):
"""Train the model on given data."""
optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
learning_rate=FLAGS.learning_rate)
loss = tf.keras.losses.CategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE)
# Compile model with Keras
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
if save_weights:
wts = model.get_weights()
np.save('save_model', wts)
model.set_weights(wts)
return model
if FLAGS.load_weights: # load preset weights
wts = np.load('save_model.npy', allow_pickle=True).tolist()
model.set_weights(wts)
# Train model with Keras
model.fit(train_x, train_y,
epochs=FLAGS.epochs,
validation_data=(train_x, train_y),
batch_size=FLAGS.batch_size,
verbose=0)
return model
def membership_test(model, pois_x, pois_y):
"""Membership inference - detect poisoning."""
probs = model.predict(np.concatenate([pois_x, np.zeros_like(pois_x)]))
return np.multiply(probs[0, :] - probs[1, :], pois_y).sum()
def train_and_score(dataset):
"""Complete training run with membership inference score."""
x, y, pois_x, pois_y, i = dataset
np.random.seed(i)
tf.set_random_seed(i)
tf.reset_default_graph()
model = build_model(x, y)
model = train_model(model, x, y)
return membership_test(model, pois_x, pois_y)
def main(unused_argv):
del unused_argv
# Load training and test data.
np.random.seed(0)
(trn_x, trn_y), _ = tf.keras.datasets.fashion_mnist.load_data()
trn_inds = np.where(trn_y < 2)[0]
trn_x = -.5 + trn_x[trn_inds] / 255.
trn_y = np.eye(2)[trn_y[trn_inds]]
# subsample dataset
ss_inds = np.random.choice(trn_x.shape[0], trn_x.shape[0]//2, replace=False)
trn_x = trn_x[ss_inds]
trn_y = trn_y[ss_inds]
init_model = build_model(trn_x, trn_y)
_ = train_model(init_model, trn_x, trn_y, save_weights=True)
auditor = audit.AuditAttack(trn_x, trn_y, train_and_score)
thresh, _, _ = auditor.run(FLAGS.pois_ct, FLAGS.attack_type, FLAGS.num_trials,
alpha=FLAGS.alpha, threshold=None,
l2_norm=FLAGS.attack_l2_norm)
_, eps, acc = auditor.run(FLAGS.pois_ct, FLAGS.attack_type, FLAGS.num_trials,
alpha=FLAGS.alpha, threshold=thresh,
l2_norm=FLAGS.attack_l2_norm)
epsilon_ub = compute_epsilon(trn_x.shape[0])
print("Analysis epsilon is {}.".format(epsilon_ub))
print("At threshold={}, epsilon={}.".format(thresh, eps))
print("The best accuracy at distinguishing poisoning is {}.".format(acc))
if __name__ == '__main__':
app.run(main)

View file

@@ -0,0 +1,156 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Auditing a model which computes the mean of a synthetic dataset.
This gives an example for instrumenting the auditor to audit a user-given sample."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
from absl import app
from absl import flags
import audit
#### FLAGS
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('d', 250, 'Data dimension')
flags.DEFINE_integer('epochs', 1, 'Number of epochs')
flags.DEFINE_integer(
'microbatches', 250, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_string('attack_type', "clip_aware", 'clip_aware or backdoor')
flags.DEFINE_integer('num_trials', 100, 'Number of trials for auditing')
flags.DEFINE_float('attack_l2_norm', 10, 'Size of poisoning data')
flags.DEFINE_float('alpha', 0.05, '1-confidence')
FLAGS = flags.FLAGS
def compute_epsilon(train_size):
"""Computes epsilon value for given hyperparameters."""
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / train_size
steps = FLAGS.epochs * train_size / FLAGS.batch_size
rdp = compute_rdp(q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to approximate 1 / (number of training points).
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
def build_model(x, y):
del x, y
model = tf.keras.Sequential([tf.keras.layers.Dense(
1, input_shape=(FLAGS.d,),
use_bias=False, kernel_initializer=tf.keras.initializers.Zeros())])
return model
def train_model(model, train_x, train_y):
"""Train the model on given data."""
optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
learning_rate=FLAGS.learning_rate)
# the gradient of the per-example loss (.5 - x.w)^2 with respect to w is -2*(.5 - x.w)*x
loss = tf.keras.losses.MeanSquaredError(reduction=tf.losses.Reduction.NONE)
# Compile model with Keras
model.compile(optimizer=optimizer, loss=loss, metrics=['mse'])
# Train model with Keras
model.fit(train_x, train_y,
epochs=FLAGS.epochs,
validation_data=(train_x, train_y),
batch_size=FLAGS.batch_size,
verbose=0)
return model
def membership_test(model, pois_x, pois_y):
"""Membership inference - detect poisoning."""
del pois_y
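# The model's prediction on the target point is used directly as the membership score.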
return model.predict(pois_x)
def gen_data(n, d):
"""Make binomial dataset."""
x = np.random.normal(size=(n, d))
y = np.ones(shape=(n,))/2.
return x, y
def train_and_score(dataset):
"""Complete training run with membership inference score."""
x, y, pois_x, pois_y, i = dataset
np.random.seed(i)
tf.set_random_seed(i)
model = build_model(x, y)
model = train_model(model, x, y)
return membership_test(model, pois_x, pois_y)
def main(unused_argv):
del unused_argv
# Load training and test data.
np.random.seed(0)
x, y = gen_data(1 + FLAGS.batch_size, FLAGS.d)
auditor = audit.AuditAttack(x, y, train_and_score)
# instrument the auditor with user-provided poisoning: the two datasets differ in
# exactly one row (one contains the target point x[-1], the other keeps x[-2])
pois_x1, pois_x2 = x[:-1].copy(), x[:-1].copy()
pois_x1[-1] = x[-1]
pois_y = y[:-1]
target_x = x[-1][None, :]
assert np.unique(np.nonzero(pois_x1 - pois_x2)[0]).size == 1
pois_data = (pois_x1, pois_y), (pois_x2, pois_y), (target_x, y[-1])
poisoning = {}
poisoning["data"] = (pois_data[0], pois_data[1])
poisoning["pois"] = pois_data[2]
auditor.poisoning = poisoning
thresh, _, _ = auditor.run(1, None, FLAGS.num_trials, alpha=FLAGS.alpha)
_, eps, acc = auditor.run(1, None, FLAGS.num_trials, alpha=FLAGS.alpha,
threshold=thresh)
epsilon_ub = compute_epsilon(FLAGS.batch_size)
print("Analysis epsilon is {}.".format(epsilon_ub))
print("At threshold={}, epsilon={}.".format(thresh, eps))
print("The best accuracy at distinguishing poisoning is {}.".format(acc))
if __name__ == '__main__':
app.run(main)