From f301595ba55c2dab6c093f7da28ca5fa32260aa1 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Mon, 10 Jan 2022 09:58:12 -0800 Subject: [PATCH] Fix a bug in keras_evaluation and its example. PiperOrigin-RevId: 420787967 --- .../keras_evaluation.py | 35 +++++++++++++------ .../keras_evaluation_example.py | 6 ++-- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py index 052132d..6091a59 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py @@ -30,19 +30,23 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard -def calculate_losses(model, data, labels): +def calculate_losses(model, data, labels, is_logit=False, batch_size=32): """Calculate losses of model prediction on data, provided true labels. Args: model: model to make prediction data: samples labels: true labels of samples (integer valued) + is_logit: whether the result of model.predict is logit or probability + batch_size: the batch size for model.predict Returns: - preds: probability vector of each sample + pred: probability vector of each sample loss: cross entropy loss of each sample """ - pred = model.predict(data) + pred = model.predict(data, batch_size=batch_size) + if is_logit: + pred = tf.nn.softmax(pred).numpy() loss = log_loss(labels, pred) return pred, loss @@ -56,7 +60,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): slicing_spec: SlicingSpec = None, attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), tensorboard_dir=None, - tensorboard_merge_classifiers=False): + tensorboard_merge_classifiers=False, + is_logit=False, + batch_size=32): """Initalizes the callback. Args: @@ -67,12 +73,16 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): tensorboard_dir: directory for tensorboard summary tensorboard_merge_classifiers: if true, plot different classifiers with the same slicing_spec and metric in the same figure + is_logit: whether the result of model.predict is logit or probability + batch_size: the batch size for model.predict """ self._in_train_data, self._in_train_labels = in_train self._out_train_data, self._out_train_labels = out_train self._slicing_spec = slicing_spec self._attack_types = attack_types self._tensorboard_merge_classifiers = tensorboard_merge_classifiers + self._is_logit = is_logit + self._batch_size = batch_size if tensorboard_dir: if tensorboard_merge_classifiers: self._writers = {} @@ -92,7 +102,8 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): (self._in_train_data, self._in_train_labels), (self._out_train_data, self._out_train_labels), self._slicing_spec, - self._attack_types) + self._attack_types, + self._is_logit, self._batch_size) logging.info(results) att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( @@ -110,7 +121,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): def run_attack_on_keras_model( model, in_train, out_train, slicing_spec: SlicingSpec = None, - attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)): + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), + is_logit: bool = False, + batch_size: int = 32): """Performs the attack on a trained model. Args: @@ -119,6 +132,8 @@ def run_attack_on_keras_model( out_train: a (out_training samples, out_training labels) tuple slicing_spec: slicing specification of the attack attack_types: a list of attacks, each of type AttackType + is_logit: whether the result of model.predict is logit or probability + batch_size: the batch size for model.predict Returns: Results of the attack """ @@ -126,10 +141,10 @@ def run_attack_on_keras_model( out_train_data, out_train_labels = out_train # Compute predictions and losses - in_train_pred, in_train_loss = calculate_losses(model, in_train_data, - in_train_labels) - out_train_pred, out_train_loss = calculate_losses(model, out_train_data, - out_train_labels) + in_train_pred, in_train_loss = calculate_losses( + model, in_train_data, in_train_labels, is_logit, batch_size) + out_train_pred, out_train_loss = calculate_losses( + model, out_train_data, out_train_labels, is_logit, batch_size) attack_input = AttackInputData( logits_train=in_train_pred, logits_test=out_train_pred, labels_train=in_train_labels, labels_test=out_train_labels, diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_example.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_example.py index d000c55..2bb7278 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_example.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_example.py @@ -83,7 +83,8 @@ def main(unused_argv): attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS], tensorboard_dir=FLAGS.model_dir, - tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers) + tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers, + is_logit=True, batch_size=2048) # Train model with Keras model.fit( @@ -101,7 +102,8 @@ def main(unused_argv): slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), attack_types=[ AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS - ]) + ], + is_logit=True, batch_size=2048) att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( attack_results) print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in