diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py index 5566410..693820e 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py @@ -23,6 +23,7 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss +from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard def calculate_losses(estimator, input_fn, labels): @@ -36,7 +37,7 @@ def calculate_losses(estimator, input_fn, labels): Args: estimator: model to make prediction input_fn: input function to be used in estimator.predict - labels: true labels of samples + labels: true labels of samples (integer valued) Returns: preds: probability vector of each sample @@ -92,13 +93,10 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) logging.info(results) - if self._writer: - summary = tf.Summary() - summary.value.add(tag='attack advantage', - simple_value=results['all_thresh_loss_advantage']) - global_step = self._estimator.get_variable_value('global_step') - self._writer.add_summary(summary, global_step) - self._writer.flush() + # Write to tensorboard if writer is specified + global_step = self._estimator.get_variable_value('global_step') + write_to_tensorboard(self._writer, ['attack advantage'], + [results['all_thresh_loss_advantage']], global_step) def run_attack_on_tf_estimator_model(estimator, in_train, out_train, diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/membership_inference_attack/utils.py index 82e30e9..d3aa5c6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/utils.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils.py @@ -19,6 +19,8 @@ from typing import Text, Dict, Union, List, Any, Tuple import numpy as np from sklearn import metrics +import tensorflow.compat.v1 as tf + ArrayDict = Dict[Text, np.ndarray] Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] @@ -236,3 +238,26 @@ def log_loss(y, pred, small_value=1e-8): the cross-entropy loss of each sample """ return -np.log(np.maximum(pred[range(y.size), y], small_value)) + + +# ------------------------------------------------------------------------------ +# Tensorboard +# ------------------------------------------------------------------------------ + + +def write_to_tensorboard(writer, tags, values, step): + """Write metrics to tensorboard. + + Args: + writer: tensorboard writer + tags: a list of tags of metrics + values: a list of values of metrics + step: step for the summary + """ + if writer is None: + return + summary = tf.Summary() + for tag, val in zip(tags, values): + summary.value.add(tag=tag, simple_value=val) + writer.add_summary(summary, step) + writer.flush()