Write to Tensorboard in Keras under TF2.

PiperOrigin-RevId: 349446504
This commit is contained in:
Shuang Song 2020-12-29 11:17:43 -08:00 committed by A. Unique TensorFlower
parent b6413a4ea9
commit 8d53d8cc59
4 changed files with 79 additions and 21 deletions

View file

@ -19,7 +19,7 @@ import os
from typing import Iterable from typing import Iterable
from absl import logging from absl import logging
import tensorflow.compat.v1 as tf import tensorflow as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
@ -27,7 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
def calculate_losses(model, data, labels): def calculate_losses(model, data, labels):
@ -76,13 +76,11 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
if tensorboard_dir: if tensorboard_dir:
if tensorboard_merge_classifiers: if tensorboard_merge_classifiers:
self._writers = {} self._writers = {}
with tf.Graph().as_default():
for attack_type in attack_types: for attack_type in attack_types:
self._writers[attack_type.name] = tf.summary.FileWriter( self._writers[attack_type.name] = tf.summary.create_file_writer(
os.path.join(tensorboard_dir, 'MI', attack_type.name)) os.path.join(tensorboard_dir, 'MI', attack_type.name))
else: else:
with tf.Graph().as_default(): self._writers = tf.summary.create_file_writer(
self._writers = tf.summary.FileWriter(
os.path.join(tensorboard_dir, 'MI')) os.path.join(tensorboard_dir, 'MI'))
logging.info('Will write to tensorboard.') logging.info('Will write to tensorboard.')
else: else:

View file

@ -19,7 +19,7 @@ from absl import app
from absl import flags from absl import flags
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec

View file

@ -102,7 +102,7 @@ def main(unused_argv):
x_train, y_train, x_test, y_test = load_cifar10() x_train, y_train, x_test, y_test = load_cifar10()
# Instantiate the tf.Estimator. # Instantiate the tf.Estimator.
mnist_classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=small_cnn_fn, model_dir=FLAGS.model_dir) model_fn=small_cnn_fn, model_dir=FLAGS.model_dir)
# A function to construct input_fn given (data, label), to be used by the # A function to construct input_fn given (data, label), to be used by the
@ -112,7 +112,7 @@ def main(unused_argv):
# Get hook for membership inference attack. # Get hook for membership inference attack.
mia_hook = MembershipInferenceTrainingHook( mia_hook = MembershipInferenceTrainingHook(
mnist_classifier, classifier,
(x_train, y_train), (x_train, y_train),
(x_test, y_test), (x_test, y_test),
input_fn_constructor, input_fn_constructor,
@ -133,20 +133,20 @@ def main(unused_argv):
x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False) x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False)
# Training loop. # Training loop.
steps_per_epoch = 60000 // FLAGS.batch_size steps_per_epoch = 50000 // FLAGS.batch_size
for epoch in range(1, FLAGS.epochs + 1): for epoch in range(1, FLAGS.epochs + 1):
# Train the model, with the membership inference hook. # Train the model, with the membership inference hook.
mnist_classifier.train( classifier.train(
input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook]) input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook])
# Evaluate the model and print results # Evaluate the model and print results
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = classifier.evaluate(input_fn=eval_input_fn)
test_accuracy = eval_results['accuracy'] test_accuracy = eval_results['accuracy']
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy)) print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
print('End of training attack') print('End of training attack')
attack_results = run_attack_on_tf_estimator_model( attack_results = run_attack_on_tf_estimator_model(
mnist_classifier, (x_train, y_train), (x_test, y_test), classifier, (x_train, y_train), (x_test, y_test),
input_fn_constructor, input_fn_constructor,
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]

View file

@ -17,7 +17,8 @@
from typing import List from typing import List
from typing import Union from typing import Union
import tensorflow.compat.v1 as tf import tensorflow as tf2
import tensorflow.compat.v1 as tf1
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
@ -41,7 +42,7 @@ def write_to_tensorboard(writers, tags, values, step):
assert len(writers) == len(tags) == len(values) assert len(writers) == len(tags) == len(values)
for writer, tag, val in zip(writers, tags, values): for writer, tag, val in zip(writers, tags, values):
summary = tf.Summary() summary = tf1.Summary()
summary.value.add(tag=tag, simple_value=val) summary.value.add(tag=tag, simple_value=val)
writer.add_summary(summary, step) writer.add_summary(summary, step)
@ -49,9 +50,37 @@ def write_to_tensorboard(writers, tags, values, step):
writer.flush() writer.flush()
def write_to_tensorboard_tf2(writers, tags, values, step):
"""Write metrics to tensorboard.
Args:
writers: a list of tensorboard writers or one writer to be used for metrics.
If it's a list, it should be of the same length as tags
tags: a list of tags of metrics
values: a list of values of metrics with the same length as tags
step: step for the tensorboard summary
"""
if writers is None or not writers:
raise ValueError('write_to_tensorboard does not get any writer.')
if not isinstance(writers, list):
writers = [writers] * len(tags)
assert len(writers) == len(tags) == len(values)
for writer, tag, val in zip(writers, tags, values):
with writer.as_default():
tf2.summary.scalar(tag, val, step=step)
writer.flush()
for writer in set(writers):
with writer.as_default():
writer.flush()
def write_results_to_tensorboard( def write_results_to_tensorboard(
attack_results: AttackResults, attack_results: AttackResults,
writers: Union[tf.summary.FileWriter, List[tf.summary.FileWriter]], writers: Union[tf1.summary.FileWriter, List[tf1.summary.FileWriter]],
step: int, step: int,
merge_classifiers: bool): merge_classifiers: bool):
"""Write attack results to tensorboard. """Write attack results to tensorboard.
@ -69,11 +98,42 @@ def write_results_to_tensorboard(
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
attack_results) attack_results)
if merge_classifiers: if merge_classifiers:
att_tags = ['attack/' + '_'.join([s, m]) for s, m in att_tags = ['attack/' + f'{s}_{m}' for s, m in
zip(att_slices, att_metrics)] zip(att_slices, att_metrics)]
write_to_tensorboard([writers[t] for t in att_types], write_to_tensorboard([writers[t] for t in att_types],
att_tags, att_values, step) att_tags, att_values, step)
else: else:
att_tags = ['attack/' + '_'.join([s, t, m]) for t, s, m in att_tags = ['attack/' + f'{s}_{t}_{m}' for t, s, m in
zip(att_types, att_slices, att_metrics)] zip(att_types, att_slices, att_metrics)]
write_to_tensorboard(writers, att_tags, att_values, step) write_to_tensorboard(writers, att_tags, att_values, step)
def write_results_to_tensorboard_tf2(
attack_results: AttackResults,
writers: Union[tf2.summary.SummaryWriter, List[tf2.summary.SummaryWriter]],
step: int,
merge_classifiers: bool):
"""Write attack results to tensorboard.
Args:
attack_results: results from attack
writers: a list of tensorboard writers or one writer to be used for metrics
step: step for the tensorboard summary
merge_classifiers: if true, plot different classifiers with the same
slicing_spec and metric in the same figure
"""
if writers is None or not writers:
raise ValueError('write_results_to_tensorboard does not get any writer.')
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
attack_results)
if merge_classifiers:
att_tags = ['attack/' + f'{s}_{m}' for s, m in
zip(att_slices, att_metrics)]
write_to_tensorboard_tf2([writers[t] for t in att_types],
att_tags, att_values, step)
else:
att_tags = ['attack/' + f'{s}_{t}_{m}' for t, s, m in
zip(att_types, att_slices, att_metrics)]
write_to_tensorboard_tf2(writers, att_tags, att_values, step)