Write to Tensorboard in Keras under TF2.
PiperOrigin-RevId: 349446504
This commit is contained in:
parent
b6413a4ea9
commit
8d53d8cc59
4 changed files with 79 additions and 21 deletions
|
@ -19,7 +19,7 @@ import os
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
@ -27,7 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard
|
from tensorflow_privacy.privacy.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
|
||||||
|
|
||||||
|
|
||||||
def calculate_losses(model, data, labels):
|
def calculate_losses(model, data, labels):
|
||||||
|
@ -76,13 +76,11 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||||
if tensorboard_dir:
|
if tensorboard_dir:
|
||||||
if tensorboard_merge_classifiers:
|
if tensorboard_merge_classifiers:
|
||||||
self._writers = {}
|
self._writers = {}
|
||||||
with tf.Graph().as_default():
|
|
||||||
for attack_type in attack_types:
|
for attack_type in attack_types:
|
||||||
self._writers[attack_type.name] = tf.summary.FileWriter(
|
self._writers[attack_type.name] = tf.summary.create_file_writer(
|
||||||
os.path.join(tensorboard_dir, 'MI', attack_type.name))
|
os.path.join(tensorboard_dir, 'MI', attack_type.name))
|
||||||
else:
|
else:
|
||||||
with tf.Graph().as_default():
|
self._writers = tf.summary.create_file_writer(
|
||||||
self._writers = tf.summary.FileWriter(
|
|
||||||
os.path.join(tensorboard_dir, 'MI'))
|
os.path.join(tensorboard_dir, 'MI'))
|
||||||
logging.info('Will write to tensorboard.')
|
logging.info('Will write to tensorboard.')
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -19,7 +19,7 @@ from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
|
|
@ -102,7 +102,7 @@ def main(unused_argv):
|
||||||
x_train, y_train, x_test, y_test = load_cifar10()
|
x_train, y_train, x_test, y_test = load_cifar10()
|
||||||
|
|
||||||
# Instantiate the tf.Estimator.
|
# Instantiate the tf.Estimator.
|
||||||
mnist_classifier = tf.estimator.Estimator(
|
classifier = tf.estimator.Estimator(
|
||||||
model_fn=small_cnn_fn, model_dir=FLAGS.model_dir)
|
model_fn=small_cnn_fn, model_dir=FLAGS.model_dir)
|
||||||
|
|
||||||
# A function to construct input_fn given (data, label), to be used by the
|
# A function to construct input_fn given (data, label), to be used by the
|
||||||
|
@ -112,7 +112,7 @@ def main(unused_argv):
|
||||||
|
|
||||||
# Get hook for membership inference attack.
|
# Get hook for membership inference attack.
|
||||||
mia_hook = MembershipInferenceTrainingHook(
|
mia_hook = MembershipInferenceTrainingHook(
|
||||||
mnist_classifier,
|
classifier,
|
||||||
(x_train, y_train),
|
(x_train, y_train),
|
||||||
(x_test, y_test),
|
(x_test, y_test),
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
|
@ -133,20 +133,20 @@ def main(unused_argv):
|
||||||
x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False)
|
x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False)
|
||||||
|
|
||||||
# Training loop.
|
# Training loop.
|
||||||
steps_per_epoch = 60000 // FLAGS.batch_size
|
steps_per_epoch = 50000 // FLAGS.batch_size
|
||||||
for epoch in range(1, FLAGS.epochs + 1):
|
for epoch in range(1, FLAGS.epochs + 1):
|
||||||
# Train the model, with the membership inference hook.
|
# Train the model, with the membership inference hook.
|
||||||
mnist_classifier.train(
|
classifier.train(
|
||||||
input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook])
|
input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook])
|
||||||
|
|
||||||
# Evaluate the model and print results
|
# Evaluate the model and print results
|
||||||
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
|
eval_results = classifier.evaluate(input_fn=eval_input_fn)
|
||||||
test_accuracy = eval_results['accuracy']
|
test_accuracy = eval_results['accuracy']
|
||||||
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
||||||
|
|
||||||
print('End of training attack')
|
print('End of training attack')
|
||||||
attack_results = run_attack_on_tf_estimator_model(
|
attack_results = run_attack_on_tf_estimator_model(
|
||||||
mnist_classifier, (x_train, y_train), (x_test, y_test),
|
classifier, (x_train, y_train), (x_test, y_test),
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
||||||
|
|
|
@ -17,7 +17,8 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow as tf2
|
||||||
|
import tensorflow.compat.v1 as tf1
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||||
|
|
||||||
|
@ -41,7 +42,7 @@ def write_to_tensorboard(writers, tags, values, step):
|
||||||
assert len(writers) == len(tags) == len(values)
|
assert len(writers) == len(tags) == len(values)
|
||||||
|
|
||||||
for writer, tag, val in zip(writers, tags, values):
|
for writer, tag, val in zip(writers, tags, values):
|
||||||
summary = tf.Summary()
|
summary = tf1.Summary()
|
||||||
summary.value.add(tag=tag, simple_value=val)
|
summary.value.add(tag=tag, simple_value=val)
|
||||||
writer.add_summary(summary, step)
|
writer.add_summary(summary, step)
|
||||||
|
|
||||||
|
@ -49,9 +50,37 @@ def write_to_tensorboard(writers, tags, values, step):
|
||||||
writer.flush()
|
writer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def write_to_tensorboard_tf2(writers, tags, values, step):
|
||||||
|
"""Write metrics to tensorboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
writers: a list of tensorboard writers or one writer to be used for metrics.
|
||||||
|
If it's a list, it should be of the same length as tags
|
||||||
|
tags: a list of tags of metrics
|
||||||
|
values: a list of values of metrics with the same length as tags
|
||||||
|
step: step for the tensorboard summary
|
||||||
|
"""
|
||||||
|
if writers is None or not writers:
|
||||||
|
raise ValueError('write_to_tensorboard does not get any writer.')
|
||||||
|
|
||||||
|
if not isinstance(writers, list):
|
||||||
|
writers = [writers] * len(tags)
|
||||||
|
|
||||||
|
assert len(writers) == len(tags) == len(values)
|
||||||
|
|
||||||
|
for writer, tag, val in zip(writers, tags, values):
|
||||||
|
with writer.as_default():
|
||||||
|
tf2.summary.scalar(tag, val, step=step)
|
||||||
|
writer.flush()
|
||||||
|
|
||||||
|
for writer in set(writers):
|
||||||
|
with writer.as_default():
|
||||||
|
writer.flush()
|
||||||
|
|
||||||
|
|
||||||
def write_results_to_tensorboard(
|
def write_results_to_tensorboard(
|
||||||
attack_results: AttackResults,
|
attack_results: AttackResults,
|
||||||
writers: Union[tf.summary.FileWriter, List[tf.summary.FileWriter]],
|
writers: Union[tf1.summary.FileWriter, List[tf1.summary.FileWriter]],
|
||||||
step: int,
|
step: int,
|
||||||
merge_classifiers: bool):
|
merge_classifiers: bool):
|
||||||
"""Write attack results to tensorboard.
|
"""Write attack results to tensorboard.
|
||||||
|
@ -69,11 +98,42 @@ def write_results_to_tensorboard(
|
||||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
||||||
attack_results)
|
attack_results)
|
||||||
if merge_classifiers:
|
if merge_classifiers:
|
||||||
att_tags = ['attack/' + '_'.join([s, m]) for s, m in
|
att_tags = ['attack/' + f'{s}_{m}' for s, m in
|
||||||
zip(att_slices, att_metrics)]
|
zip(att_slices, att_metrics)]
|
||||||
write_to_tensorboard([writers[t] for t in att_types],
|
write_to_tensorboard([writers[t] for t in att_types],
|
||||||
att_tags, att_values, step)
|
att_tags, att_values, step)
|
||||||
else:
|
else:
|
||||||
att_tags = ['attack/' + '_'.join([s, t, m]) for t, s, m in
|
att_tags = ['attack/' + f'{s}_{t}_{m}' for t, s, m in
|
||||||
zip(att_types, att_slices, att_metrics)]
|
zip(att_types, att_slices, att_metrics)]
|
||||||
write_to_tensorboard(writers, att_tags, att_values, step)
|
write_to_tensorboard(writers, att_tags, att_values, step)
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_to_tensorboard_tf2(
|
||||||
|
attack_results: AttackResults,
|
||||||
|
writers: Union[tf2.summary.SummaryWriter, List[tf2.summary.SummaryWriter]],
|
||||||
|
step: int,
|
||||||
|
merge_classifiers: bool):
|
||||||
|
"""Write attack results to tensorboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attack_results: results from attack
|
||||||
|
writers: a list of tensorboard writers or one writer to be used for metrics
|
||||||
|
step: step for the tensorboard summary
|
||||||
|
merge_classifiers: if true, plot different classifiers with the same
|
||||||
|
slicing_spec and metric in the same figure
|
||||||
|
"""
|
||||||
|
if writers is None or not writers:
|
||||||
|
raise ValueError('write_results_to_tensorboard does not get any writer.')
|
||||||
|
|
||||||
|
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
||||||
|
attack_results)
|
||||||
|
if merge_classifiers:
|
||||||
|
att_tags = ['attack/' + f'{s}_{m}' for s, m in
|
||||||
|
zip(att_slices, att_metrics)]
|
||||||
|
write_to_tensorboard_tf2([writers[t] for t in att_types],
|
||||||
|
att_tags, att_values, step)
|
||||||
|
else:
|
||||||
|
att_tags = ['attack/' + f'{s}_{t}_{m}' for t, s, m in
|
||||||
|
zip(att_types, att_slices, att_metrics)]
|
||||||
|
write_to_tensorboard_tf2(writers, att_tags, att_values, step)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue