forked from 626_privacy/tensorflow_privacy
A separate tensorboard function.
PiperOrigin-RevId: 322820408
This commit is contained in:
parent
2ec0f36d1e
commit
267ea7f90d
2 changed files with 31 additions and 8 deletions
|
@ -23,6 +23,7 @@ import tensorflow.compat.v1 as tf
|
|||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||
|
||||
|
||||
def calculate_losses(estimator, input_fn, labels):
|
||||
|
@ -36,7 +37,7 @@ def calculate_losses(estimator, input_fn, labels):
|
|||
Args:
|
||||
estimator: model to make prediction
|
||||
input_fn: input function to be used in estimator.predict
|
||||
labels: true labels of samples
|
||||
labels: true labels of samples (integer valued)
|
||||
|
||||
Returns:
|
||||
preds: probability vector of each sample
|
||||
|
@ -92,13 +93,10 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
|||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
||||
logging.info(results)
|
||||
|
||||
if self._writer:
|
||||
summary = tf.Summary()
|
||||
summary.value.add(tag='attack advantage',
|
||||
simple_value=results['all_thresh_loss_advantage'])
|
||||
global_step = self._estimator.get_variable_value('global_step')
|
||||
self._writer.add_summary(summary, global_step)
|
||||
self._writer.flush()
|
||||
# Write to tensorboard if writer is specified
|
||||
global_step = self._estimator.get_variable_value('global_step')
|
||||
write_to_tensorboard(self._writer, ['attack advantage'],
|
||||
[results['all_thresh_loss_advantage']], global_step)
|
||||
|
||||
|
||||
def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
||||
|
|
|
@ -19,6 +19,8 @@ from typing import Text, Dict, Union, List, Any, Tuple
|
|||
|
||||
import numpy as np
|
||||
from sklearn import metrics
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
ArrayDict = Dict[Text, np.ndarray]
|
||||
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
||||
|
@ -236,3 +238,26 @@ def log_loss(y, pred, small_value=1e-8):
|
|||
the cross-entropy loss of each sample
|
||||
"""
|
||||
return -np.log(np.maximum(pred[range(y.size), y], small_value))
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Tensorboard
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def write_to_tensorboard(writer, tags, values, step):
|
||||
"""Write metrics to tensorboard.
|
||||
|
||||
Args:
|
||||
writer: tensorboard writer
|
||||
tags: a list of tags of metrics
|
||||
values: a list of values of metrics
|
||||
step: step for the summary
|
||||
"""
|
||||
if writer is None:
|
||||
return
|
||||
summary = tf.Summary()
|
||||
for tag, val in zip(tags, values):
|
||||
summary.value.add(tag=tag, simple_value=val)
|
||||
writer.add_summary(summary, step)
|
||||
writer.flush()
|
||||
|
|
Loading…
Reference in a new issue