A separate tensorboard function.
PiperOrigin-RevId: 322820408
This commit is contained in:
parent
2ec0f36d1e
commit
267ea7f90d
2 changed files with 31 additions and 8 deletions
|
@ -23,6 +23,7 @@ import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||||
|
|
||||||
|
|
||||||
def calculate_losses(estimator, input_fn, labels):
|
def calculate_losses(estimator, input_fn, labels):
|
||||||
|
@ -36,7 +37,7 @@ def calculate_losses(estimator, input_fn, labels):
|
||||||
Args:
|
Args:
|
||||||
estimator: model to make prediction
|
estimator: model to make prediction
|
||||||
input_fn: input function to be used in estimator.predict
|
input_fn: input function to be used in estimator.predict
|
||||||
labels: true labels of samples
|
labels: true labels of samples (integer valued)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
preds: probability vector of each sample
|
preds: probability vector of each sample
|
||||||
|
@ -92,13 +93,10 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
|
|
||||||
if self._writer:
|
# Write to tensorboard if writer is specified
|
||||||
summary = tf.Summary()
|
|
||||||
summary.value.add(tag='attack advantage',
|
|
||||||
simple_value=results['all_thresh_loss_advantage'])
|
|
||||||
global_step = self._estimator.get_variable_value('global_step')
|
global_step = self._estimator.get_variable_value('global_step')
|
||||||
self._writer.add_summary(summary, global_step)
|
write_to_tensorboard(self._writer, ['attack advantage'],
|
||||||
self._writer.flush()
|
[results['all_thresh_loss_advantage']], global_step)
|
||||||
|
|
||||||
|
|
||||||
def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
||||||
|
|
|
@ -19,6 +19,8 @@ from typing import Text, Dict, Union, List, Any, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
|
||||||
ArrayDict = Dict[Text, np.ndarray]
|
ArrayDict = Dict[Text, np.ndarray]
|
||||||
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
||||||
|
@ -236,3 +238,26 @@ def log_loss(y, pred, small_value=1e-8):
|
||||||
the cross-entropy loss of each sample
|
the cross-entropy loss of each sample
|
||||||
"""
|
"""
|
||||||
return -np.log(np.maximum(pred[range(y.size), y], small_value))
|
return -np.log(np.maximum(pred[range(y.size), y], small_value))
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Tensorboard
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def write_to_tensorboard(writer, tags, values, step):
|
||||||
|
"""Write metrics to tensorboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
writer: tensorboard writer
|
||||||
|
tags: a list of tags of metrics
|
||||||
|
values: a list of values of metrics
|
||||||
|
step: step for the summary
|
||||||
|
"""
|
||||||
|
if writer is None:
|
||||||
|
return
|
||||||
|
summary = tf.Summary()
|
||||||
|
for tag, val in zip(tags, values):
|
||||||
|
summary.value.add(tag=tag, simple_value=val)
|
||||||
|
writer.add_summary(summary, step)
|
||||||
|
writer.flush()
|
||||||
|
|
Loading…
Reference in a new issue