forked from 626_privacy/tensorflow_privacy
Option for plotting attack results in the same figure.
PiperOrigin-RevId: 333225502
This commit is contained in:
parent 677b3d9e9a
commit 7c53757250
9 changed files with 261 additions and 186 deletions
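In short: MembershipInferenceCallback and MembershipInferenceTrainingHook gain a tensorboard_merge_classifiers option; when set, each attack type gets its own FileWriter under <tensorboard_dir>/MI/<attack_type>, so TensorBoard overlays the same slice/metric across attack types in one figure. A minimal usage sketch, assuming x_train/y_train/x_test/y_test are already loaded as in the examples below (the log directory is hypothetical):

  from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
  from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
  from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback

  mia_callback = MembershipInferenceCallback(
      (x_train, y_train), (x_test, y_test),
      slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
      attack_types=[AttackType.THRESHOLD_ATTACK,
                    AttackType.K_NEAREST_NEIGHBORS],
      tensorboard_dir='/tmp/tb',  # hypothetical log directory
      tensorboard_merge_classifiers=True)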
@@ -573,18 +573,19 @@ def get_flattened_attack_metrics(results: AttackResults):
    results: membership inference attack results.

  Returns:
    properties: a list of (slice, attack_type, metric name)
    types: a list of attack types
    slices: a list of slices
    attack_metrics: a list of metric names
    values: a list of metric values, i-th element corresponds to properties[i]
  """
  properties = []
  types = []
  slices = []
  attack_metrics = []
  values = []
  for attack_result in results.single_attack_results:
    slice_spec = attack_result.slice_spec
    prop = [str(slice_spec), str(attack_result.attack_type)]
    properties += [prop + ['adv'], prop + ['auc']]
    values += [
        float(attack_result.get_attacker_advantage()),
        float(attack_result.get_auc())
    ]

  return properties, values
    types += [str(attack_result.attack_type)] * 2
    slices += [str(attack_result.slice_spec)] * 2
    attack_metrics += ['adv', 'auc']
    values += [float(attack_result.get_attacker_advantage()),
               float(attack_result.get_auc())]
  return types, slices, attack_metrics, values
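get_flattened_attack_metrics now returns four parallel lists (types, slices, metrics, values) instead of (properties, values). A hedged consumption sketch, with `results` standing for any AttackResults object:

  att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
      results)
  for t, s, m, v in zip(att_types, att_slices, att_metrics, att_values):
    # e.g. " THRESHOLD_ATTACK, Entire dataset, auc: 0.6123"
    # (the slice string is illustrative, not actual library output)
    print(' %s, %s, %s: %.4f' % (t, s, m, v))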
@@ -15,8 +15,8 @@
# Lint as: python3
"""A callback and a function in keras for membership inference attack."""

import os
from typing import Iterable

from absl import logging

import tensorflow.compat.v1 as tf

@@ -27,7 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard


def calculate_losses(model, data, labels):

@@ -55,7 +55,8 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
               in_train, out_train,
               slicing_spec: SlicingSpec = None,
               attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
               tensorboard_dir=None):
               tensorboard_dir=None,
               tensorboard_merge_classifiers=False):
    """Initializes the callback.

    Args:

@@ -64,18 +65,28 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
      slicing_spec: slicing specification of the attack
      attack_types: a list of attacks, each of type AttackType
      tensorboard_dir: directory for tensorboard summary
      tensorboard_merge_classifiers: if true, plot different classifiers with
        the same slicing_spec and metric in the same figure
    """
    self._in_train_data, self._in_train_labels = in_train
    self._out_train_data, self._out_train_labels = out_train
    self._slicing_spec = slicing_spec
    self._attack_types = attack_types
    # Setup tensorboard writer if tensorboard_dir is specified
    self._tensorboard_merge_classifiers = tensorboard_merge_classifiers
    if tensorboard_dir:
      if tensorboard_merge_classifiers:
        self._writers = {}
      with tf.Graph().as_default():
        self._writer = tf.summary.FileWriter(tensorboard_dir)
        for attack_type in attack_types:
          self._writers[attack_type.name] = tf.summary.FileWriter(
              os.path.join(tensorboard_dir, 'MI', attack_type.name))
      else:
        with tf.Graph().as_default():
          self._writers = tf.summary.FileWriter(
              os.path.join(tensorboard_dir, 'MI'))
      logging.info('Will write to tensorboard.')
    else:
      self._writer = None
      self._writers = None

  def on_epoch_end(self, epoch, logs=None):
    results = run_attack_on_keras_model(

@@ -86,15 +97,16 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
        self._attack_types)
    logging.info(results)

    attack_properties, attack_values = get_flattened_attack_metrics(results)
    att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
        results)
    print('Attack result:')
    print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
                     zip(attack_properties, attack_values)]))
    print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in
                     zip(att_types, att_slices, att_metrics, att_values)]))

    # Write to tensorboard if tensorboard_dir is specified
    attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
    write_to_tensorboard(self._writer, attack_property_tags, attack_values,
                         epoch)
    if self._writers is not None:
      write_results_to_tensorboard(results, self._writers, epoch,
                                   self._tensorboard_merge_classifiers)


def run_attack_on_keras_model(
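The per-attack-type writers above determine the TensorBoard log layout. A sketch of what the constructor builds (directory names follow the os.path.join calls in the diff; '/tmp/tb' is a hypothetical directory):

  import os
  import tensorflow.compat.v1 as tf

  # tensorboard_merge_classifiers=True: one FileWriter per attack type,
  #   /tmp/tb/MI/THRESHOLD_ATTACK, /tmp/tb/MI/K_NEAREST_NEIGHBORS, ...
  # tensorboard_merge_classifiers=False: a single FileWriter at /tmp/tb/MI.
  writers = {}
  for attack_type in attack_types:  # attack_types as passed to the callback
    writers[attack_type.name] = tf.summary.FileWriter(
        os.path.join('/tmp/tb', 'MI', attack_type.name))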
@@ -26,95 +26,86 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model

GradientDescentOptimizer = tf.train.GradientDescentOptimizer

FLAGS = flags.FLAGS

flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 10, 'Number of epochs')
flags.DEFINE_float('learning_rate', 0.02, 'Learning rate for training')
flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('epochs', 100, 'Number of epochs')
flags.DEFINE_string('model_dir', None, 'Model directory.')
flags.DEFINE_bool('tensorboard_merge_classifiers', False, 'If true, plot '
                  'different classifiers with the same slicing_spec and metric '
                  'in the same figure.')


def cnn_model():
  """Define a CNN model."""
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          16,
          8,
          strides=2,
          padding='same',
          activation='relu',
          input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPool2D(2, 1),
      tf.keras.layers.Conv2D(
          32, 4, strides=2, padding='valid', activation='relu'),
      tf.keras.layers.MaxPool2D(2, 1),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(32, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
def small_cnn():
  """Sets up a small CNN for image classification."""
  model = tf.keras.models.Sequential()
  model.add(tf.keras.layers.Input(shape=(32, 32, 3)))

  for _ in range(3):
    model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
    model.add(tf.keras.layers.MaxPooling2D())

  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(64, activation='relu'))
  model.add(tf.keras.layers.Dense(10))
  return model


def load_mnist():
  """Loads MNIST and preprocesses to combine training and validation data."""
  (train_data,
   train_labels), (test_data,
                   test_labels) = tf.keras.datasets.mnist.load_data()
def load_cifar10():
  """Loads CIFAR10 data."""
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

  train_data = np.array(train_data, dtype=np.float32) / 255
  test_data = np.array(test_data, dtype=np.float32) / 255
  x_train = np.array(x_train, dtype=np.float32) / 255
  x_test = np.array(x_test, dtype=np.float32) / 255

  train_data = train_data.reshape((train_data.shape[0], 28, 28, 1))
  test_data = test_data.reshape((test_data.shape[0], 28, 28, 1))
  y_train = np.array(y_train, dtype=np.int32).squeeze()
  y_test = np.array(y_test, dtype=np.int32).squeeze()

  train_labels = np.array(train_labels, dtype=np.int32)
  test_labels = np.array(test_labels, dtype=np.int32)

  return train_data, train_labels, test_data, test_labels
  return x_train, y_train, x_test, y_test


def main(unused_argv):
  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()
  x_train, y_train, x_test, y_test = load_cifar10()

  # Get model, optimizer and specify loss.
  model = cnn_model()
  optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
  model = small_cnn()
  optimizer = tf.keras.optimizers.SGD(lr=FLAGS.learning_rate, momentum=0.9)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

  # Get callback for membership inference attack.
  mia_callback = MembershipInferenceCallback(
      (train_data, train_labels), (test_data, test_labels),
      attack_types=[AttackType.THRESHOLD_ATTACK],
      tensorboard_dir=FLAGS.model_dir)
      (x_train, y_train),
      (x_test, y_test),
      slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
      attack_types=[AttackType.THRESHOLD_ATTACK,
                    AttackType.K_NEAREST_NEIGHBORS],
      tensorboard_dir=FLAGS.model_dir,
      tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers)

  # Train model with Keras
  model.fit(
      train_data,
      train_labels,
      x_train,
      y_train,
      epochs=FLAGS.epochs,
      validation_data=(test_data, test_labels),
      validation_data=(x_test, y_test),
      batch_size=FLAGS.batch_size,
      callbacks=[mia_callback],
      verbose=2)

  print('End of training attack:')
  attack_results = run_attack_on_keras_model(
      model, (train_data, train_labels), (test_data, test_labels),
      model, (x_train, y_train), (x_test, y_test),
      slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
      attack_types=[
          AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
      ])

  attack_properties, attack_values = get_flattened_attack_metrics(
  att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
      attack_results)
  print('\n'.join([
      ' %s: %.4f' % (', '.join(p), r)
      for p, r in zip(attack_properties, attack_values)
  ]))
  print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in
                   zip(att_types, att_slices, att_metrics, att_values)]))


if __name__ == '__main__':
@@ -67,9 +67,12 @@ class UtilsTest(absltest.TestCase):
        (self.test_data, self.test_labels),
        attack_types=[AttackType.THRESHOLD_ATTACK])
    self.assertIsInstance(results, AttackResults)
    attack_properties, attack_values = get_flattened_attack_metrics(results)
    self.assertLen(attack_properties, 2)
    self.assertLen(attack_values, 2)
    att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
        results)
    self.assertLen(att_types, 2)
    self.assertLen(att_slices, 2)
    self.assertLen(att_metrics, 2)
    self.assertLen(att_values, 2)


if __name__ == '__main__':
@@ -15,21 +15,18 @@
# Lint as: python3
"""A hook and a function in tf estimator for membership inference attack."""

import os
from typing import Iterable

from absl import logging

import numpy as np

import tensorflow.compat.v1 as tf

from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard


def calculate_losses(estimator, input_fn, labels):

@@ -43,7 +40,7 @@ def calculate_losses(estimator, input_fn, labels):
  Args:
    estimator: model to make prediction
    input_fn: input function to be used in estimator.predict
    labels: true labels of samples (integer valued)
    labels: array of size (n_samples, ), true labels of samples (integer valued)

  Returns:
    preds: probability vector of each sample

@@ -64,7 +61,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
               input_fn_constructor,
               slicing_spec: SlicingSpec = None,
               attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
               writer=None):
               tensorboard_dir=None,
               tensorboard_merge_classifiers=False):
    """Initializes the hook.

    Args:

@@ -75,7 +73,9 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
        the input_fn for model prediction
      slicing_spec: slicing specification of the attack
      attack_types: a list of attacks, each of type AttackType
      writer: summary writer for tensorboard
      tensorboard_dir: directory for tensorboard summary
      tensorboard_merge_classifiers: if true, plot different classifiers with
        the same slicing_spec and metric in the same figure
    """
    in_train_data, self._in_train_labels = in_train
    out_train_data, self._out_train_labels = out_train

@@ -88,9 +88,21 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
    self._estimator = estimator
    self._slicing_spec = slicing_spec
    self._attack_types = attack_types
    self._writer = writer
    if self._writer:
    self._tensorboard_merge_classifiers = tensorboard_merge_classifiers
    if tensorboard_dir:
      if tensorboard_merge_classifiers:
        self._writers = {}
        with tf.Graph().as_default():
          for attack_type in attack_types:
            self._writers[attack_type.name] = tf.summary.FileWriter(
                os.path.join(tensorboard_dir, 'MI', attack_type.name))
      else:
        with tf.Graph().as_default():
          self._writers = tf.summary.FileWriter(
              os.path.join(tensorboard_dir, 'MI'))
      logging.info('Will write to tensorboard.')
    else:
      self._writers = None

  def end(self, session):
    results = run_attack_helper(self._estimator,

@@ -101,16 +113,17 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
                                self._attack_types)
    logging.info(results)

    attack_properties, attack_values = get_flattened_attack_metrics(results)
    att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
        results)
    print('Attack result:')
    print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
                     zip(attack_properties, attack_values)]))
    print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in
                     zip(att_types, att_slices, att_metrics, att_values)]))

    # Write to tensorboard if writer is specified
    # Write to tensorboard if tensorboard_dir is specified
    global_step = self._estimator.get_variable_value('global_step')
    attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
    write_to_tensorboard(self._writer, attack_property_tags, attack_values,
                         global_step)
    if self._writers is not None:
      write_results_to_tensorboard(results, self._writers, global_step,
                                   self._tensorboard_merge_classifiers)


def run_attack_on_tf_estimator_model(

@@ -184,4 +197,3 @@ def run_attack_helper(
      slicing_spec=slicing_spec,
      attack_types=attack_types)
  return results
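Callers of the estimator hook migrate from the old writer= argument to tensorboard_dir (the hook now constructs its own FileWriters, as in __init__ above). A hedged before/after sketch, with `classifier` and the data tuples standing in for the example's variables:

  # Before this commit:
  mia_hook = MembershipInferenceTrainingHook(
      classifier, (x_train, y_train), (x_test, y_test),
      input_fn_constructor, writer=summary_writer)

  # After this commit:
  mia_hook = MembershipInferenceTrainingHook(
      classifier, (x_train, y_train), (x_test, y_test),
      input_fn_constructor,
      tensorboard_dir=FLAGS.model_dir,
      tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers)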
@@ -27,30 +27,27 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model

GradientDescentOptimizer = tf.train.GradientDescentOptimizer

FLAGS = flags.FLAGS

flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 10, 'Number of epochs')
flags.DEFINE_float('learning_rate', 0.02, 'Learning rate for training')
flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('epochs', 100, 'Number of epochs')
flags.DEFINE_string('model_dir', None, 'Model directory.')
flags.DEFINE_bool('tensorboard_merge_classifiers', False, 'If true, plot '
                  'different classifiers with the same slicing_spec and metric '
                  'in the same figure.')


def cnn_model_fn(features, labels, mode):
  """Model function for a CNN."""
def small_cnn_fn(features, labels, mode):
  """Sets up a small CNN for image classification."""
  input_layer = tf.reshape(features['x'], [-1, 32, 32, 3])
  for _ in range(3):
    y = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')(input_layer)
    y = tf.keras.layers.MaxPool2D()(y)

  # Define CNN architecture using tf.keras.layers.
  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
  y = tf.keras.layers.Conv2D(
      16, 8, strides=2, padding='same', activation='relu').apply(input_layer)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Conv2D(
      32, 4, strides=2, padding='valid', activation='relu').apply(y)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Flatten().apply(y)
  y = tf.keras.layers.Dense(32, activation='relu').apply(y)
  logits = tf.keras.layers.Dense(10).apply(y)
  y = tf.keras.layers.Flatten()(y)
  y = tf.keras.layers.Dense(64, activation='relu')(y)
  logits = tf.keras.layers.Dense(10)(y)

  if mode != tf.estimator.ModeKeys.PREDICT:
    vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(

@@ -59,7 +56,8 @@ def cnn_model_fn(features, labels, mode):

  # Configure the training op (for TRAIN mode).
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    optimizer = tf.train.MomentumOptimizer(learning_rate=FLAGS.learning_rate,
                                           momentum=0.9)
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(

@@ -81,19 +79,17 @@ def cnn_model_fn(features, labels, mode):
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)


def load_mnist():
  """Loads MNIST and preprocesses to combine training and validation data."""
  (train_data,
   train_labels), (test_data,
                   test_labels) = tf.keras.datasets.mnist.load_data()
def load_cifar10():
  """Loads CIFAR10 data."""
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

  train_data = np.array(train_data, dtype=np.float32) / 255
  test_data = np.array(test_data, dtype=np.float32) / 255
  x_train = np.array(x_train, dtype=np.float32) / 255
  x_test = np.array(x_test, dtype=np.float32) / 255

  train_labels = np.array(train_labels, dtype=np.int32)
  test_labels = np.array(test_labels, dtype=np.int32)
  y_train = np.array(y_train, dtype=np.int32).squeeze()
  y_test = np.array(y_test, dtype=np.int32).squeeze()

  return train_data, train_labels, test_data, test_labels
  return x_train, y_train, x_test, y_test


def main(unused_argv):

@@ -103,39 +99,38 @@ def main(unused_argv):
    logging.get_absl_handler().use_absl_log_file()

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()
  x_train, y_train, x_test, y_test = load_cifar10()

  # Instantiate the tf.Estimator.
  mnist_classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir=FLAGS.model_dir)
      model_fn=small_cnn_fn, model_dir=FLAGS.model_dir)

  # A function to construct input_fn given (data, label), to be used by the
  # membership inference training hook.
  def input_fn_constructor(x, y):
    return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False)

  with tf.Graph().as_default():
    # Get a summary writer for the hook to write to tensorboard.
    # Can set summary_writer to None if not needed.
    if FLAGS.model_dir:
      summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    else:
      summary_writer = None
  # Get hook for membership inference attack.
  mia_hook = MembershipInferenceTrainingHook(
      mnist_classifier, (train_data, train_labels), (test_data, test_labels),
      mnist_classifier,
      (x_train, y_train),
      (x_test, y_test),
      input_fn_constructor,
      attack_types=[AttackType.THRESHOLD_ATTACK],
      writer=summary_writer)
      slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
      attack_types=[AttackType.THRESHOLD_ATTACK,
                    AttackType.K_NEAREST_NEIGHBORS],
      tensorboard_dir=FLAGS.model_dir,
      tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers)

  # Create tf.Estimator input functions for the training and test data.
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': train_data},
      y=train_labels,
      x={'x': x_train},
      y=y_train,
      batch_size=FLAGS.batch_size,
      num_epochs=FLAGS.epochs,
      shuffle=True)
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)
      x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False)

  # Training loop.
  steps_per_epoch = 60000 // FLAGS.batch_size

@@ -151,18 +146,15 @@ def main(unused_argv):

  print('End of training attack')
  attack_results = run_attack_on_tf_estimator_model(
      mnist_classifier, (train_data, train_labels), (test_data, test_labels),
      mnist_classifier, (x_train, y_train), (x_test, y_test),
      input_fn_constructor,
      slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
      attack_types=[
          AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
      ])
  attack_properties, attack_values = get_flattened_attack_metrics(
      attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
  )
  att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
      attack_results)
  print('\n'.join([
      ' %s: %.4f' % (', '.join(p), r)
      for p, r in zip(attack_properties, attack_values)
  ]))
  print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in
                   zip(att_types, att_slices, att_metrics, att_values)]))


if __name__ == '__main__':
@@ -88,9 +88,12 @@ class UtilsTest(absltest.TestCase):
        self.test_labels,
        attack_types=[AttackType.THRESHOLD_ATTACK])
    self.assertIsInstance(results, AttackResults)
    attack_properties, attack_values = get_flattened_attack_metrics(results)
    self.assertLen(attack_properties, 2)
    self.assertLen(attack_values, 2)
    att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
        results)
    self.assertLen(att_types, 2)
    self.assertLen(att_slices, 2)
    self.assertLen(att_metrics, 2)
    self.assertLen(att_values, 2)

  def test_run_attack_on_tf_estimator_model(self):
    """Test the attack on the final models."""

@@ -104,9 +107,12 @@ class UtilsTest(absltest.TestCase):
        input_fn_constructor,
        attack_types=[AttackType.THRESHOLD_ATTACK])
    self.assertIsInstance(results, AttackResults)
    attack_properties, attack_values = get_flattened_attack_metrics(results)
    self.assertLen(attack_properties, 2)
    self.assertLen(attack_values, 2)
    att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
        results)
    self.assertLen(att_types, 2)
    self.assertLen(att_slices, 2)
    self.assertLen(att_metrics, 2)
    self.assertLen(att_values, 2)


if __name__ == '__main__':
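The updated test assertions follow from the flattening rule in the first hunk of this commit: each SingleAttackResult contributes exactly two entries ('adv' and 'auc') to every list, so one threshold attack on the entire dataset yields length-2 lists. A hedged sketch of the invariant:

  att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
      results)
  assert len(att_types) == len(att_slices) == len(att_metrics) == len(att_values)
  assert len(att_values) == 2 * len(results.single_attack_results)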
@@ -20,7 +20,7 @@ from typing import Text, Dict, Union, List, Any, Tuple
import numpy as np
import scipy.special
from sklearn import metrics
import tensorflow.compat.v1 as tf


ArrayDict = Dict[Text, np.ndarray]
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]

@@ -229,10 +229,12 @@ def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
  """Compute the cross entropy loss.

  Args:
    labels: numpy array, labels[i] is the true label (scalar) of the i-th sample
    pred: numpy array, pred[i] is the probability vector of the i-th sample
    small_value: np.log can become -inf if the probability is too close to 0, so
      the probability is clipped below by small_value.
    labels: numpy array of shape (num_samples,), labels[i] is the true label
      (scalar) of the i-th sample
    pred: numpy array of shape (num_samples, num_classes), where pred[i] is the
      probability vector of the i-th sample
    small_value: a scalar. np.log can become -inf if the probability is too
      close to 0, so the probability is clipped below by small_value.

  Returns:
    the cross-entropy loss of each sample

@@ -243,26 +245,3 @@ def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
  """Compute the cross entropy loss from logits."""
  return log_loss(labels, scipy.special.softmax(logits, axis=-1))


# ------------------------------------------------------------------------------
# Tensorboard
# ------------------------------------------------------------------------------


def write_to_tensorboard(writer, tags, values, step):
  """Write metrics to tensorboard.

  Args:
    writer: tensorboard writer
    tags: a list of tags of metrics
    values: a list of values of metrics
    step: step for the summary
  """
  if writer is None:
    return
  summary = tf.Summary()
  for tag, val in zip(tags, values):
    summary.value.add(tag=tag, simple_value=val)
  writer.add_summary(summary, step)
  writer.flush()
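The clarified log_loss docstring pins down the shapes. A small worked example of the documented behavior (a sketch of the formula, not the library implementation; the clipping mirrors the small_value argument):

  import numpy as np

  labels = np.array([0, 2])              # shape (num_samples,)
  pred = np.array([[0.7, 0.2, 0.1],      # shape (num_samples, num_classes)
                   [0.1, 0.1, 0.8]])
  # Per-sample cross entropy: -log(pred[i, labels[i]]), clipped below by 1e-8.
  losses = -np.log(np.maximum(pred[np.arange(len(labels)), labels], 1e-8))
  # losses ≈ [0.357, 0.223]  (= [-log 0.7, -log 0.8])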
@@ -0,0 +1,79 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Utility functions for writing attack results to tensorboard."""
from typing import List
from typing import Union

import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics


def write_to_tensorboard(writers, tags, values, step):
  """Write metrics to tensorboard.

  Args:
    writers: a list of tensorboard writers or one writer to be used for metrics.
      If it's a list, it should be of the same length as tags
    tags: a list of tags of metrics
    values: a list of values of metrics with the same length as tags
    step: step for the tensorboard summary
  """
  if writers is None or not writers:
    raise ValueError('write_to_tensorboard does not get any writer.')

  if not isinstance(writers, list):
    writers = [writers] * len(tags)

  assert len(writers) == len(tags) == len(values)

  for writer, tag, val in zip(writers, tags, values):
    summary = tf.Summary()
    summary.value.add(tag=tag, simple_value=val)
    writer.add_summary(summary, step)

  for writer in set(writers):
    writer.flush()


def write_results_to_tensorboard(
    attack_results: AttackResults,
    writers: Union[tf.summary.FileWriter, List[tf.summary.FileWriter]],
    step: int,
    merge_classifiers: bool):
  """Write attack results to tensorboard.

  Args:
    attack_results: results from attack
    writers: a list of tensorboard writers or one writer to be used for metrics
    step: step for the tensorboard summary
    merge_classifiers: if true, plot different classifiers with the same
      slicing_spec and metric in the same figure
  """
  if writers is None or not writers:
    raise ValueError('write_results_to_tensorboard does not get any writer.')

  att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
      attack_results)
  if merge_classifiers:
    att_tags = ['attack/' + '_'.join([s, m]) for s, m in
                zip(att_slices, att_metrics)]
    write_to_tensorboard([writers[t] for t in att_types],
                         att_tags, att_values, step)
  else:
    att_tags = ['attack/' + '_'.join([s, t, m]) for t, s, m in
                zip(att_types, att_slices, att_metrics)]
    write_to_tensorboard(writers, att_tags, att_values, step)