From 8d53d8cc59c65a671865fb8d16f8f1ef4e87c823 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Tue, 29 Dec 2020 11:17:43 -0800 Subject: [PATCH] Write to Tensorboard in Keras under TF2. PiperOrigin-RevId: 349446504 --- .../keras_evaluation.py | 16 ++--- .../keras_evaluation_example.py | 2 +- .../tf_estimator_evaluation_example.py | 12 ++-- .../utils_tensorboard.py | 70 +++++++++++++++++-- 4 files changed, 79 insertions(+), 21 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py index 5f504cb..54354c5 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py @@ -19,7 +19,7 @@ import os from typing import Iterable from absl import logging -import tensorflow.compat.v1 as tf +import tensorflow as tf from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData @@ -27,7 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss -from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard +from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard def calculate_losses(model, data, labels): @@ -76,14 +76,12 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): if tensorboard_dir: if tensorboard_merge_classifiers: self._writers = {} - with tf.Graph().as_default(): - for attack_type in attack_types: - self._writers[attack_type.name] = tf.summary.FileWriter( - os.path.join(tensorboard_dir, 'MI', attack_type.name)) + for attack_type in attack_types: + self._writers[attack_type.name] = tf.summary.create_file_writer( + os.path.join(tensorboard_dir, 'MI', attack_type.name)) else: - with tf.Graph().as_default(): - self._writers = tf.summary.FileWriter( - os.path.join(tensorboard_dir, 'MI')) + self._writers = tf.summary.create_file_writer( + os.path.join(tensorboard_dir, 'MI')) logging.info('Will write to tensorboard.') else: self._writers = None diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py index 4a19552..26862b8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py @@ -19,7 +19,7 @@ from absl import app from absl import flags import numpy as np -import tensorflow.compat.v1 as tf +import tensorflow as tf from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py index 477848f..6f97aa1 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py @@ -102,7 +102,7 @@ def main(unused_argv): x_train, y_train, x_test, y_test = load_cifar10() # Instantiate the tf.Estimator. - mnist_classifier = tf.estimator.Estimator( + classifier = tf.estimator.Estimator( model_fn=small_cnn_fn, model_dir=FLAGS.model_dir) # A function to construct input_fn given (data, label), to be used by the @@ -112,7 +112,7 @@ def main(unused_argv): # Get hook for membership inference attack. mia_hook = MembershipInferenceTrainingHook( - mnist_classifier, + classifier, (x_train, y_train), (x_test, y_test), input_fn_constructor, @@ -133,20 +133,20 @@ def main(unused_argv): x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False) # Training loop. - steps_per_epoch = 60000 // FLAGS.batch_size + steps_per_epoch = 50000 // FLAGS.batch_size for epoch in range(1, FLAGS.epochs + 1): # Train the model, with the membership inference hook. - mnist_classifier.train( + classifier.train( input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook]) # Evaluate the model and print results - eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) + eval_results = classifier.evaluate(input_fn=eval_input_fn) test_accuracy = eval_results['accuracy'] print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy)) print('End of training attack') attack_results = run_attack_on_tf_estimator_model( - mnist_classifier, (x_train, y_train), (x_test, y_test), + classifier, (x_train, y_train), (x_test, y_test), input_fn_constructor, slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils_tensorboard.py b/tensorflow_privacy/privacy/membership_inference_attack/utils_tensorboard.py index 799ca18..adedf33 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/utils_tensorboard.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils_tensorboard.py @@ -17,7 +17,8 @@ from typing import List from typing import Union -import tensorflow.compat.v1 as tf +import tensorflow as tf2 +import tensorflow.compat.v1 as tf1 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics @@ -41,7 +42,7 @@ def write_to_tensorboard(writers, tags, values, step): assert len(writers) == len(tags) == len(values) for writer, tag, val in zip(writers, tags, values): - summary = tf.Summary() + summary = tf1.Summary() summary.value.add(tag=tag, simple_value=val) writer.add_summary(summary, step) @@ -49,9 +50,37 @@ def write_to_tensorboard(writers, tags, values, step): writer.flush() +def write_to_tensorboard_tf2(writers, tags, values, step): + """Write metrics to tensorboard. + + Args: + writers: a list of tensorboard writers or one writer to be used for metrics. + If it's a list, it should be of the same length as tags + tags: a list of tags of metrics + values: a list of values of metrics with the same length as tags + step: step for the tensorboard summary + """ + if writers is None or not writers: + raise ValueError('write_to_tensorboard does not get any writer.') + + if not isinstance(writers, list): + writers = [writers] * len(tags) + + assert len(writers) == len(tags) == len(values) + + for writer, tag, val in zip(writers, tags, values): + with writer.as_default(): + tf2.summary.scalar(tag, val, step=step) + writer.flush() + + for writer in set(writers): + with writer.as_default(): + writer.flush() + + def write_results_to_tensorboard( attack_results: AttackResults, - writers: Union[tf.summary.FileWriter, List[tf.summary.FileWriter]], + writers: Union[tf1.summary.FileWriter, List[tf1.summary.FileWriter]], step: int, merge_classifiers: bool): """Write attack results to tensorboard. @@ -69,11 +98,42 @@ def write_results_to_tensorboard( att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( attack_results) if merge_classifiers: - att_tags = ['attack/' + '_'.join([s, m]) for s, m in + att_tags = ['attack/' + f'{s}_{m}' for s, m in zip(att_slices, att_metrics)] write_to_tensorboard([writers[t] for t in att_types], att_tags, att_values, step) else: - att_tags = ['attack/' + '_'.join([s, t, m]) for t, s, m in + att_tags = ['attack/' + f'{s}_{t}_{m}' for t, s, m in zip(att_types, att_slices, att_metrics)] write_to_tensorboard(writers, att_tags, att_values, step) + + +def write_results_to_tensorboard_tf2( + attack_results: AttackResults, + writers: Union[tf2.summary.SummaryWriter, List[tf2.summary.SummaryWriter]], + step: int, + merge_classifiers: bool): + """Write attack results to tensorboard. + + Args: + attack_results: results from attack + writers: a list of tensorboard writers or one writer to be used for metrics + step: step for the tensorboard summary + merge_classifiers: if true, plot different classifiers with the same + slicing_spec and metric in the same figure + """ + if writers is None or not writers: + raise ValueError('write_results_to_tensorboard does not get any writer.') + + att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + attack_results) + if merge_classifiers: + att_tags = ['attack/' + f'{s}_{m}' for s, m in + zip(att_slices, att_metrics)] + write_to_tensorboard_tf2([writers[t] for t in att_types], + att_tags, att_values, step) + else: + att_tags = ['attack/' + f'{s}_{t}_{m}' for t, s, m in + zip(att_types, att_slices, att_metrics)] + write_to_tensorboard_tf2(writers, att_tags, att_values, step) +