diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py index ef15734..1b72086 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py @@ -14,19 +14,15 @@ """A hook and a function in tf estimator for membership inference attack.""" import os - from typing import Iterable + from absl import logging import numpy as np -import tensorflow.compat.v1 as tf - +import tensorflow as tf +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard def calculate_losses(estimator, input_fn, labels): @@ -47,23 +43,23 @@ def calculate_losses(estimator, input_fn, labels): loss: cross entropy loss of each sample """ pred = np.array(list(estimator.predict(input_fn=input_fn))) - loss = log_loss(labels, pred) + loss = utils.log_loss(labels, pred) return pred, loss class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): """Training hook to perform membership inference attack on epoch end.""" - def __init__( - self, - estimator, - in_train, - out_train, - input_fn_constructor, - slicing_spec: SlicingSpec = None, - attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), - tensorboard_dir=None, - tensorboard_merge_classifiers=False): + def __init__(self, + estimator, + in_train, + out_train, + input_fn_constructor, + slicing_spec: data_structures.SlicingSpec = None, + attack_types: Iterable[data_structures.AttackType] = ( + data_structures.AttackType.THRESHOLD_ATTACK,), + tensorboard_dir=None, + tensorboard_merge_classifiers=False): """Initialize the hook. Args: @@ -112,7 +108,7 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): self._attack_types) logging.info(results) - att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( + att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( results) print('Attack result:') print('\n'.join([ @@ -123,8 +119,9 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): # Write to tensorboard if tensorboard_dir is specified global_step = self._estimator.get_variable_value('global_step') if self._writers is not None: - write_results_to_tensorboard(results, self._writers, global_step, - self._tensorboard_merge_classifiers) + utils_tensorboard.write_results_to_tensorboard( + results, self._writers, global_step, + self._tensorboard_merge_classifiers) def run_attack_on_tf_estimator_model( @@ -132,8 +129,9 @@ def run_attack_on_tf_estimator_model( in_train, out_train, input_fn_constructor, - slicing_spec: SlicingSpec = None, - attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)): + slicing_spec: data_structures.SlicingSpec = None, + attack_types: Iterable[data_structures.AttackType] = ( + data_structures.AttackType.THRESHOLD_ATTACK,)): """Performs the attack in the end of training. Args: @@ -164,14 +162,14 @@ def run_attack_on_tf_estimator_model( return results -def run_attack_helper( - estimator, - in_train_input_fn, - out_train_input_fn, - in_train_labels, - out_train_labels, - slicing_spec: SlicingSpec = None, - attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)): +def run_attack_helper(estimator, + in_train_input_fn, + out_train_input_fn, + in_train_labels, + out_train_labels, + slicing_spec: data_structures.SlicingSpec = None, + attack_types: Iterable[data_structures.AttackType] = ( + data_structures.AttackType.THRESHOLD_ATTACK,)): """A helper function to perform attack. Args: @@ -192,7 +190,7 @@ def run_attack_helper( out_train_pred, out_train_loss = calculate_losses(estimator, out_train_input_fn, out_train_labels) - attack_input = AttackInputData( + attack_input = data_structures.AttackInputData( logits_train=in_train_pred, logits_test=out_train_pred, labels_train=in_train_labels, diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_tensorboard.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_tensorboard.py index afaf596..2c150e0 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_tensorboard.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_tensorboard.py @@ -13,42 +13,13 @@ # limitations under the License. """Utility functions for writing attack results to tensorboard.""" -from typing import List -from typing import Union +from typing import List, Union -import tensorflow as tf2 -import tensorflow.compat.v1 as tf1 +import tensorflow as tf from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics -def write_to_tensorboard(writers, tags, values, step): - """Write metrics to tensorboard. - - Args: - writers: a list of tensorboard writers or one writer to be used for metrics. - If it's a list, it should be of the same length as tags - tags: a list of tags of metrics - values: a list of values of metrics with the same length as tags - step: step for the tensorboard summary - """ - if writers is None or not writers: - raise ValueError('write_to_tensorboard does not get any writer.') - - if not isinstance(writers, list): - writers = [writers] * len(tags) - - assert len(writers) == len(tags) == len(values) - - for writer, tag, val in zip(writers, tags, values): - summary = tf1.Summary() - summary.value.add(tag=tag, simple_value=val) - writer.add_summary(summary, step) - - for writer in set(writers): - writer.flush() - - def write_to_tensorboard_tf2(writers, tags, values, step): """Write metrics to tensorboard. @@ -69,7 +40,7 @@ def write_to_tensorboard_tf2(writers, tags, values, step): for writer, tag, val in zip(writers, tags, values): with writer.as_default(): - tf2.summary.scalar(tag, val, step=step) + tf.summary.scalar(tag, val, step=step) writer.flush() for writer in set(writers): @@ -77,39 +48,9 @@ def write_to_tensorboard_tf2(writers, tags, values, step): writer.flush() -def write_results_to_tensorboard(attack_results: AttackResults, - writers: Union[tf1.summary.FileWriter, - List[tf1.summary.FileWriter]], - step: int, merge_classifiers: bool): - """Write attack results to tensorboard. - - Args: - attack_results: results from attack - writers: a list of tensorboard writers or one writer to be used for metrics - step: step for the tensorboard summary - merge_classifiers: if true, plot different classifiers with the same - slicing_spec and metric in the same figure - """ - if writers is None or not writers: - raise ValueError('write_results_to_tensorboard does not get any writer.') - - att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( - attack_results) - if merge_classifiers: - att_tags = ['attack/' + f'{s}_{m}' for s, m in zip(att_slices, att_metrics)] - write_to_tensorboard([writers[t] for t in att_types], att_tags, att_values, - step) - else: - att_tags = [ - 'attack/' + f'{s}_{t}_{m}' - for t, s, m in zip(att_types, att_slices, att_metrics) - ] - write_to_tensorboard(writers, att_tags, att_values, step) - - def write_results_to_tensorboard_tf2( attack_results: AttackResults, - writers: Union[tf2.summary.SummaryWriter, List[tf2.summary.SummaryWriter]], + writers: Union[tf.summary.SummaryWriter, List[tf.summary.SummaryWriter]], step: int, merge_classifiers: bool): """Write attack results to tensorboard.