diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index cfd67a6..f3c3158 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -573,18 +573,19 @@ def get_flattened_attack_metrics(results: AttackResults): results: membership inference attack results. Returns: - properties: a list of (slice, attack_type, metric name) + types: a list of attack types + slices: a list of slices + attack_metrics: a list of metric names values: a list of metric values, i-th element correspond to properties[i] """ - properties = [] + types = [] + slices = [] + attack_metrics = [] values = [] for attack_result in results.single_attack_results: - slice_spec = attack_result.slice_spec - prop = [str(slice_spec), str(attack_result.attack_type)] - properties += [prop + ['adv'], prop + ['auc']] - values += [ - float(attack_result.get_attacker_advantage()), - float(attack_result.get_auc()) - ] - - return properties, values + types += [str(attack_result.attack_type)] * 2 + slices += [str(attack_result.slice_spec)] * 2 + attack_metrics += ['adv', 'auc'] + values += [float(attack_result.get_attacker_advantage()), + float(attack_result.get_auc())] + return types, slices, attack_metrics, values diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py index e43346a..9bbab25 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py @@ -15,8 +15,8 @@ # Lint as: python3 """A callback and a function in keras for membership inference attack.""" +import os from typing import Iterable - from absl import logging import tensorflow.compat.v1 as tf @@ -27,7 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss -from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard +from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard def calculate_losses(model, data, labels): @@ -55,7 +55,8 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): in_train, out_train, slicing_spec: SlicingSpec = None, attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), - tensorboard_dir=None): + tensorboard_dir=None, + tensorboard_merge_classifiers=False): """Initalizes the callback. Args: @@ -64,18 +65,28 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): slicing_spec: slicing specification of the attack attack_types: a list of attacks, each of type AttackType tensorboard_dir: directory for tensorboard summary + tensorboard_merge_classifiers: if true, plot different classifiers with + the same slicing_spec and metric in the same figure """ self._in_train_data, self._in_train_labels = in_train self._out_train_data, self._out_train_labels = out_train self._slicing_spec = slicing_spec self._attack_types = attack_types - # Setup tensorboard writer if tensorboard_dir is specified + self._tensorboard_merge_classifiers = tensorboard_merge_classifiers if tensorboard_dir: - with tf.Graph().as_default(): - self._writer = tf.summary.FileWriter(tensorboard_dir) + if tensorboard_merge_classifiers: + self._writers = {} + with tf.Graph().as_default(): + for attack_type in attack_types: + self._writers[attack_type.name] = tf.summary.FileWriter( + os.path.join(tensorboard_dir, 'MI', attack_type.name)) + else: + with tf.Graph().as_default(): + self._writers = tf.summary.FileWriter( + os.path.join(tensorboard_dir, 'MI')) logging.info('Will write to tensorboard.') else: - self._writer = None + self._writers = None def on_epoch_end(self, epoch, logs=None): results = run_attack_on_keras_model( @@ -86,15 +97,16 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): self._attack_types) logging.info(results) - attack_properties, attack_values = get_flattened_attack_metrics(results) + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + results) print('Attack result:') - print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in - zip(attack_properties, attack_values)])) + print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in + zip(att_types, att_slices, att_metrics, att_values)])) # Write to tensorboard if tensorboard_dir is specified - attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties] - write_to_tensorboard(self._writer, attack_property_tags, attack_values, - epoch) + if self._writers is not None: + write_results_to_tensorboard(results, self._writers, epoch, + self._tensorboard_merge_classifiers) def run_attack_on_keras_model( diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py index 50102dc..4a19552 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py @@ -26,95 +26,86 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model -GradientDescentOptimizer = tf.train.GradientDescentOptimizer FLAGS = flags.FLAGS - -flags.DEFINE_float('learning_rate', .15, 'Learning rate for training') -flags.DEFINE_integer('batch_size', 256, 'Batch size') -flags.DEFINE_integer('epochs', 10, 'Number of epochs') +flags.DEFINE_float('learning_rate', 0.02, 'Learning rate for training') +flags.DEFINE_integer('batch_size', 250, 'Batch size') +flags.DEFINE_integer('epochs', 100, 'Number of epochs') flags.DEFINE_string('model_dir', None, 'Model directory.') +flags.DEFINE_bool('tensorboard_merge_classifiers', False, 'If true, plot ' + 'different classifiers with the same slicing_spec and metric ' + 'in the same figure.') -def cnn_model(): - """Define a CNN model.""" - model = tf.keras.Sequential([ - tf.keras.layers.Conv2D( - 16, - 8, - strides=2, - padding='same', - activation='relu', - input_shape=(28, 28, 1)), - tf.keras.layers.MaxPool2D(2, 1), - tf.keras.layers.Conv2D( - 32, 4, strides=2, padding='valid', activation='relu'), - tf.keras.layers.MaxPool2D(2, 1), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(32, activation='relu'), - tf.keras.layers.Dense(10) - ]) +def small_cnn(): + """Setup a small CNN for image classification.""" + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Input(shape=(32, 32, 3))) + + for _ in range(3): + model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu')) + model.add(tf.keras.layers.MaxPooling2D()) + + model.add(tf.keras.layers.Flatten()) + model.add(tf.keras.layers.Dense(64, activation='relu')) + model.add(tf.keras.layers.Dense(10)) return model -def load_mnist(): - """Loads MNIST and preprocesses to combine training and validation data.""" - (train_data, - train_labels), (test_data, - test_labels) = tf.keras.datasets.mnist.load_data() +def load_cifar10(): + """Loads CIFAR10 data.""" + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() - train_data = np.array(train_data, dtype=np.float32) / 255 - test_data = np.array(test_data, dtype=np.float32) / 255 + x_train = np.array(x_train, dtype=np.float32) / 255 + x_test = np.array(x_test, dtype=np.float32) / 255 - train_data = train_data.reshape((train_data.shape[0], 28, 28, 1)) - test_data = test_data.reshape((test_data.shape[0], 28, 28, 1)) + y_train = np.array(y_train, dtype=np.int32).squeeze() + y_test = np.array(y_test, dtype=np.int32).squeeze() - train_labels = np.array(train_labels, dtype=np.int32) - test_labels = np.array(test_labels, dtype=np.int32) - - return train_data, train_labels, test_data, test_labels + return x_train, y_train, x_test, y_test def main(unused_argv): # Load training and test data. - train_data, train_labels, test_data, test_labels = load_mnist() + x_train, y_train, x_test, y_test = load_cifar10() # Get model, optimizer and specify loss. - model = cnn_model() - optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate) + model = small_cnn() + optimizer = tf.keras.optimizers.SGD(lr=FLAGS.learning_rate, momentum=0.9) loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) # Get callback for membership inference attack. mia_callback = MembershipInferenceCallback( - (train_data, train_labels), (test_data, test_labels), - attack_types=[AttackType.THRESHOLD_ATTACK], - tensorboard_dir=FLAGS.model_dir) + (x_train, y_train), + (x_test, y_test), + slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), + attack_types=[AttackType.THRESHOLD_ATTACK, + AttackType.K_NEAREST_NEIGHBORS], + tensorboard_dir=FLAGS.model_dir, + tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers) # Train model with Keras model.fit( - train_data, - train_labels, + x_train, + y_train, epochs=FLAGS.epochs, - validation_data=(test_data, test_labels), + validation_data=(x_test, y_test), batch_size=FLAGS.batch_size, callbacks=[mia_callback], verbose=2) print('End of training attack:') attack_results = run_attack_on_keras_model( - model, (train_data, train_labels), (test_data, test_labels), + model, (x_train, y_train), (x_test, y_test), slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), attack_types=[ AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS ]) - - attack_properties, attack_values = get_flattened_attack_metrics( + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( attack_results) - print('\n'.join([ - ' %s: %.4f' % (', '.join(p), r) - for p, r in zip(attack_properties, attack_values) - ])) + print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in + zip(att_types, att_slices, att_metrics, att_values)])) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py index 6bdb0fc..cddaadc 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py @@ -67,9 +67,12 @@ class UtilsTest(absltest.TestCase): (self.test_data, self.test_labels), attack_types=[AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, AttackResults) - attack_properties, attack_values = get_flattened_attack_metrics(results) - self.assertLen(attack_properties, 2) - self.assertLen(attack_values, 2) + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + results) + self.assertLen(att_types, 2) + self.assertLen(att_slices, 2) + self.assertLen(att_metrics, 2) + self.assertLen(att_values, 2) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py index 9a61cd0..0f5bae8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py @@ -15,21 +15,18 @@ # Lint as: python3 """A hook and a function in tf estimator for membership inference attack.""" +import os from typing import Iterable - from absl import logging - import numpy as np - import tensorflow.compat.v1 as tf - from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss -from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard +from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard def calculate_losses(estimator, input_fn, labels): @@ -43,7 +40,7 @@ def calculate_losses(estimator, input_fn, labels): Args: estimator: model to make prediction input_fn: input function to be used in estimator.predict - labels: true labels of samples (integer valued) + labels: array of size (n_samples, ), true labels of samples (integer valued) Returns: preds: probability vector of each sample @@ -64,7 +61,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): input_fn_constructor, slicing_spec: SlicingSpec = None, attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), - writer=None): + tensorboard_dir=None, + tensorboard_merge_classifiers=False): """Initialize the hook. Args: @@ -75,7 +73,9 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): the input_fn for model prediction slicing_spec: slicing specification of the attack attack_types: a list of attacks, each of type AttackType - writer: summary writer for tensorboard + tensorboard_dir: directory for tensorboard summary + tensorboard_merge_classifiers: if true, plot different classifiers with + the same slicing_spec and metric in the same figure """ in_train_data, self._in_train_labels = in_train out_train_data, self._out_train_labels = out_train @@ -88,9 +88,21 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): self._estimator = estimator self._slicing_spec = slicing_spec self._attack_types = attack_types - self._writer = writer - if self._writer: + self._tensorboard_merge_classifiers = tensorboard_merge_classifiers + if tensorboard_dir: + if tensorboard_merge_classifiers: + self._writers = {} + with tf.Graph().as_default(): + for attack_type in attack_types: + self._writers[attack_type.name] = tf.summary.FileWriter( + os.path.join(tensorboard_dir, 'MI', attack_type.name)) + else: + with tf.Graph().as_default(): + self._writers = tf.summary.FileWriter( + os.path.join(tensorboard_dir, 'MI')) logging.info('Will write to tensorboard.') + else: + self._writers = None def end(self, session): results = run_attack_helper(self._estimator, @@ -101,16 +113,17 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): self._attack_types) logging.info(results) - attack_properties, attack_values = get_flattened_attack_metrics(results) + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + results) print('Attack result:') - print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in - zip(attack_properties, attack_values)])) + print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in + zip(att_types, att_slices, att_metrics, att_values)])) - # Write to tensorboard if writer is specified + # Write to tensorboard if tensorboard_dir is specified global_step = self._estimator.get_variable_value('global_step') - attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties] - write_to_tensorboard(self._writer, attack_property_tags, attack_values, - global_step) + if self._writers is not None: + write_results_to_tensorboard(results, self._writers, global_step, + self._tensorboard_merge_classifiers) def run_attack_on_tf_estimator_model( @@ -184,4 +197,3 @@ def run_attack_helper( slicing_spec=slicing_spec, attack_types=attack_types) return results - diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py index c579491..477848f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py @@ -27,30 +27,27 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model -GradientDescentOptimizer = tf.train.GradientDescentOptimizer FLAGS = flags.FLAGS - -flags.DEFINE_float('learning_rate', .15, 'Learning rate for training') -flags.DEFINE_integer('batch_size', 256, 'Batch size') -flags.DEFINE_integer('epochs', 10, 'Number of epochs') +flags.DEFINE_float('learning_rate', 0.02, 'Learning rate for training') +flags.DEFINE_integer('batch_size', 250, 'Batch size') +flags.DEFINE_integer('epochs', 100, 'Number of epochs') flags.DEFINE_string('model_dir', None, 'Model directory.') +flags.DEFINE_bool('tensorboard_merge_classifiers', False, 'If true, plot ' + 'different classifiers with the same slicing_spec and metric ' + 'in the same figure.') -def cnn_model_fn(features, labels, mode): - """Model function for a CNN.""" +def small_cnn_fn(features, labels, mode): + """Setup a small CNN for image classification.""" + input_layer = tf.reshape(features['x'], [-1, 32, 32, 3]) + for _ in range(3): + y = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')(input_layer) + y = tf.keras.layers.MaxPool2D()(y) - # Define CNN architecture using tf.keras.layers. - input_layer = tf.reshape(features['x'], [-1, 28, 28, 1]) - y = tf.keras.layers.Conv2D( - 16, 8, strides=2, padding='same', activation='relu').apply(input_layer) - y = tf.keras.layers.MaxPool2D(2, 1).apply(y) - y = tf.keras.layers.Conv2D( - 32, 4, strides=2, padding='valid', activation='relu').apply(y) - y = tf.keras.layers.MaxPool2D(2, 1).apply(y) - y = tf.keras.layers.Flatten().apply(y) - y = tf.keras.layers.Dense(32, activation='relu').apply(y) - logits = tf.keras.layers.Dense(10).apply(y) + y = tf.keras.layers.Flatten()(y) + y = tf.keras.layers.Dense(64, activation='relu')(y) + logits = tf.keras.layers.Dense(10)(y) if mode != tf.estimator.ModeKeys.PREDICT: vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( @@ -59,7 +56,8 @@ def cnn_model_fn(features, labels, mode): # Configure the training op (for TRAIN mode). if mode == tf.estimator.ModeKeys.TRAIN: - optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate) + optimizer = tf.train.MomentumOptimizer(learning_rate=FLAGS.learning_rate, + momentum=0.9) global_step = tf.train.get_global_step() train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step) return tf.estimator.EstimatorSpec( @@ -81,19 +79,17 @@ def cnn_model_fn(features, labels, mode): return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) -def load_mnist(): - """Loads MNIST and preprocesses to combine training and validation data.""" - (train_data, - train_labels), (test_data, - test_labels) = tf.keras.datasets.mnist.load_data() +def load_cifar10(): + """Loads CIFAR10 data.""" + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() - train_data = np.array(train_data, dtype=np.float32) / 255 - test_data = np.array(test_data, dtype=np.float32) / 255 + x_train = np.array(x_train, dtype=np.float32) / 255 + x_test = np.array(x_test, dtype=np.float32) / 255 - train_labels = np.array(train_labels, dtype=np.int32) - test_labels = np.array(test_labels, dtype=np.int32) + y_train = np.array(y_train, dtype=np.int32).squeeze() + y_test = np.array(y_test, dtype=np.int32).squeeze() - return train_data, train_labels, test_data, test_labels + return x_train, y_train, x_test, y_test def main(unused_argv): @@ -103,39 +99,38 @@ def main(unused_argv): logging.get_absl_handler().use_absl_log_file() # Load training and test data. - train_data, train_labels, test_data, test_labels = load_mnist() + x_train, y_train, x_test, y_test = load_cifar10() # Instantiate the tf.Estimator. mnist_classifier = tf.estimator.Estimator( - model_fn=cnn_model_fn, model_dir=FLAGS.model_dir) + model_fn=small_cnn_fn, model_dir=FLAGS.model_dir) # A function to construct input_fn given (data, label), to be used by the # membership inference training hook. def input_fn_constructor(x, y): return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False) - with tf.Graph().as_default(): - # Get a summary writer for the hook to write to tensorboard. - # Can set summary_writer to None if not needed. - if FLAGS.model_dir: - summary_writer = tf.summary.FileWriter(FLAGS.model_dir) - else: - summary_writer = None - mia_hook = MembershipInferenceTrainingHook( - mnist_classifier, (train_data, train_labels), (test_data, test_labels), - input_fn_constructor, - attack_types=[AttackType.THRESHOLD_ATTACK], - writer=summary_writer) + # Get hook for membership inference attack. + mia_hook = MembershipInferenceTrainingHook( + mnist_classifier, + (x_train, y_train), + (x_test, y_test), + input_fn_constructor, + slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), + attack_types=[AttackType.THRESHOLD_ATTACK, + AttackType.K_NEAREST_NEIGHBORS], + tensorboard_dir=FLAGS.model_dir, + tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers) # Create tf.Estimator input functions for the training and test data. train_input_fn = tf.estimator.inputs.numpy_input_fn( - x={'x': train_data}, - y=train_labels, + x={'x': x_train}, + y=y_train, batch_size=FLAGS.batch_size, num_epochs=FLAGS.epochs, shuffle=True) eval_input_fn = tf.estimator.inputs.numpy_input_fn( - x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False) + x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False) # Training loop. steps_per_epoch = 60000 // FLAGS.batch_size @@ -151,18 +146,15 @@ def main(unused_argv): print('End of training attack') attack_results = run_attack_on_tf_estimator_model( - mnist_classifier, (train_data, train_labels), (test_data, test_labels), + mnist_classifier, (x_train, y_train), (x_test, y_test), input_fn_constructor, slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), - attack_types=[ - AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS - ]) - attack_properties, attack_values = get_flattened_attack_metrics( + attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] + ) + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( attack_results) - print('\n'.join([ - ' %s: %.4f' % (', '.join(p), r) - for p, r in zip(attack_properties, attack_values) - ])) + print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in + zip(att_types, att_slices, att_metrics, att_values)])) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py index f8f8ce2..44c618f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py @@ -88,9 +88,12 @@ class UtilsTest(absltest.TestCase): self.test_labels, attack_types=[AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, AttackResults) - attack_properties, attack_values = get_flattened_attack_metrics(results) - self.assertLen(attack_properties, 2) - self.assertLen(attack_values, 2) + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + results) + self.assertLen(att_types, 2) + self.assertLen(att_slices, 2) + self.assertLen(att_metrics, 2) + self.assertLen(att_values, 2) def test_run_attack_on_tf_estimator_model(self): """Test the attack on the final models.""" @@ -104,9 +107,12 @@ class UtilsTest(absltest.TestCase): input_fn_constructor, attack_types=[AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, AttackResults) - attack_properties, attack_values = get_flattened_attack_metrics(results) - self.assertLen(attack_properties, 2) - self.assertLen(attack_values, 2) + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + results) + self.assertLen(att_types, 2) + self.assertLen(att_slices, 2) + self.assertLen(att_metrics, 2) + self.assertLen(att_values, 2) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/membership_inference_attack/utils.py index 67d7e99..24785a7 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/utils.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils.py @@ -20,7 +20,7 @@ from typing import Text, Dict, Union, List, Any, Tuple import numpy as np import scipy.special from sklearn import metrics -import tensorflow.compat.v1 as tf + ArrayDict = Dict[Text, np.ndarray] Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] @@ -229,10 +229,12 @@ def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8): """Compute the cross entropy loss. Args: - labels: numpy array, labels[i] is the true label (scalar) of the i-th sample - pred: numpy array, pred[i] is the probability vector of the i-th sample - small_value: np.log can become -inf if the probability is too close to 0, so - the probability is clipped below by small_value. + labels: numpy array of shape (num_samples,) labels[i] is the true label + (scalar) of the i-th sample + pred: numpy array of shape(num_samples, num_classes) where pred[i] is the + probability vector of the i-th sample + small_value: a scalar. np.log can become -inf if the probability is too + close to 0, so the probability is clipped below by small_value. Returns: the cross-entropy loss of each sample @@ -243,26 +245,3 @@ def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8): def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray): """Compute the cross entropy loss from logits.""" return log_loss(labels, scipy.special.softmax(logits, axis=-1)) - - -# ------------------------------------------------------------------------------ -# Tensorboard -# ------------------------------------------------------------------------------ - - -def write_to_tensorboard(writer, tags, values, step): - """Write metrics to tensorboard. - - Args: - writer: tensorboard writer - tags: a list of tags of metrics - values: a list of values of metrics - step: step for the summary - """ - if writer is None: - return - summary = tf.Summary() - for tag, val in zip(tags, values): - summary.value.add(tag=tag, simple_value=val) - writer.add_summary(summary, step) - writer.flush() diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils_tensorboard.py b/tensorflow_privacy/privacy/membership_inference_attack/utils_tensorboard.py new file mode 100644 index 0000000..799ca18 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils_tensorboard.py @@ -0,0 +1,79 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Utility functions for writing attack results to tensorboard.""" +from typing import List +from typing import Union + +import tensorflow.compat.v1 as tf +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics + + +def write_to_tensorboard(writers, tags, values, step): + """Write metrics to tensorboard. + + Args: + writers: a list of tensorboard writers or one writer to be used for metrics. + If it's a list, it should be of the same length as tags + tags: a list of tags of metrics + values: a list of values of metrics with the same length as tags + step: step for the tensorboard summary + """ + if writers is None or not writers: + raise ValueError('write_to_tensorboard does not get any writer.') + + if not isinstance(writers, list): + writers = [writers] * len(tags) + + assert len(writers) == len(tags) == len(values) + + for writer, tag, val in zip(writers, tags, values): + summary = tf.Summary() + summary.value.add(tag=tag, simple_value=val) + writer.add_summary(summary, step) + + for writer in set(writers): + writer.flush() + + +def write_results_to_tensorboard( + attack_results: AttackResults, + writers: Union[tf.summary.FileWriter, List[tf.summary.FileWriter]], + step: int, + merge_classifiers: bool): + """Write attack results to tensorboard. + + Args: + attack_results: results from attack + writers: a list of tensorboard writers or one writer to be used for metrics + step: step for the tensorboard summary + merge_classifiers: if true, plot different classifiers with the same + slicing_spec and metric in the same figure + """ + if writers is None or not writers: + raise ValueError('write_results_to_tensorboard does not get any writer.') + + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + attack_results) + if merge_classifiers: + att_tags = ['attack/' + '_'.join([s, m]) for s, m in + zip(att_slices, att_metrics)] + write_to_tensorboard([writers[t] for t in att_types], + att_tags, att_values, step) + else: + att_tags = ['attack/' + '_'.join([s, t, m]) for t, s, m in + zip(att_types, att_slices, att_metrics)] + write_to_tensorboard(writers, att_tags, att_values, step)