A separate tensorboard function.

PiperOrigin-RevId: 322820408
This commit is contained in:
Shuang Song 2020-07-23 10:55:24 -07:00 committed by A. Unique TensorFlower
parent 2ec0f36d1e
commit 267ea7f90d
2 changed files with 31 additions and 8 deletions

View file

@ -23,6 +23,7 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
def calculate_losses(estimator, input_fn, labels): def calculate_losses(estimator, input_fn, labels):
@ -36,7 +37,7 @@ def calculate_losses(estimator, input_fn, labels):
Args: Args:
estimator: model to make prediction estimator: model to make prediction
input_fn: input function to be used in estimator.predict input_fn: input function to be used in estimator.predict
labels: true labels of samples labels: true labels of samples (integer valued)
Returns: Returns:
preds: probability vector of each sample preds: probability vector of each sample
@ -92,13 +93,10 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
logging.info(results) logging.info(results)
if self._writer: # Write to tensorboard if writer is specified
summary = tf.Summary()
summary.value.add(tag='attack advantage',
simple_value=results['all_thresh_loss_advantage'])
global_step = self._estimator.get_variable_value('global_step') global_step = self._estimator.get_variable_value('global_step')
self._writer.add_summary(summary, global_step) write_to_tensorboard(self._writer, ['attack advantage'],
self._writer.flush() [results['all_thresh_loss_advantage']], global_step)
def run_attack_on_tf_estimator_model(estimator, in_train, out_train, def run_attack_on_tf_estimator_model(estimator, in_train, out_train,

View file

@ -19,6 +19,8 @@ from typing import Text, Dict, Union, List, Any, Tuple
import numpy as np import numpy as np
from sklearn import metrics from sklearn import metrics
import tensorflow.compat.v1 as tf
ArrayDict = Dict[Text, np.ndarray] ArrayDict = Dict[Text, np.ndarray]
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
@ -236,3 +238,26 @@ def log_loss(y, pred, small_value=1e-8):
the cross-entropy loss of each sample the cross-entropy loss of each sample
""" """
return -np.log(np.maximum(pred[range(y.size), y], small_value)) return -np.log(np.maximum(pred[range(y.size), y], small_value))
# ------------------------------------------------------------------------------
# Tensorboard
# ------------------------------------------------------------------------------
def write_to_tensorboard(writer, tags, values, step):
"""Write metrics to tensorboard.
Args:
writer: tensorboard writer
tags: a list of tags of metrics
values: a list of values of metrics
step: step for the summary
"""
if writer is None:
return
summary = tf.Summary()
for tag, val in zip(tags, values):
summary.value.add(tag=tag, simple_value=val)
writer.add_summary(summary, step)
writer.flush()