forked from 626_privacy/tensorflow_privacy
Delete unused TF 1.0 API in TensorFlow Privacy.
PiperOrigin-RevId: 425900761
This commit is contained in:
parent
6fde7b0480
commit
778c804d1b
2 changed files with 36 additions and 97 deletions
|
@ -14,19 +14,15 @@
|
||||||
"""A hook and a function in tf estimator for membership inference attack."""
|
"""A hook and a function in tf estimator for membership inference attack."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_losses(estimator, input_fn, labels):
|
def calculate_losses(estimator, input_fn, labels):
|
||||||
|
@ -47,21 +43,21 @@ def calculate_losses(estimator, input_fn, labels):
|
||||||
loss: cross entropy loss of each sample
|
loss: cross entropy loss of each sample
|
||||||
"""
|
"""
|
||||||
pred = np.array(list(estimator.predict(input_fn=input_fn)))
|
pred = np.array(list(estimator.predict(input_fn=input_fn)))
|
||||||
loss = log_loss(labels, pred)
|
loss = utils.log_loss(labels, pred)
|
||||||
return pred, loss
|
return pred, loss
|
||||||
|
|
||||||
|
|
||||||
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
"""Training hook to perform membership inference attack on epoch end."""
|
"""Training hook to perform membership inference attack on epoch end."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
|
||||||
estimator,
|
estimator,
|
||||||
in_train,
|
in_train,
|
||||||
out_train,
|
out_train,
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
slicing_spec: SlicingSpec = None,
|
slicing_spec: data_structures.SlicingSpec = None,
|
||||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
attack_types: Iterable[data_structures.AttackType] = (
|
||||||
|
data_structures.AttackType.THRESHOLD_ATTACK,),
|
||||||
tensorboard_dir=None,
|
tensorboard_dir=None,
|
||||||
tensorboard_merge_classifiers=False):
|
tensorboard_merge_classifiers=False):
|
||||||
"""Initialize the hook.
|
"""Initialize the hook.
|
||||||
|
@ -112,7 +108,7 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
self._attack_types)
|
self._attack_types)
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
|
|
||||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
||||||
results)
|
results)
|
||||||
print('Attack result:')
|
print('Attack result:')
|
||||||
print('\n'.join([
|
print('\n'.join([
|
||||||
|
@ -123,7 +119,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
# Write to tensorboard if tensorboard_dir is specified
|
# Write to tensorboard if tensorboard_dir is specified
|
||||||
global_step = self._estimator.get_variable_value('global_step')
|
global_step = self._estimator.get_variable_value('global_step')
|
||||||
if self._writers is not None:
|
if self._writers is not None:
|
||||||
write_results_to_tensorboard(results, self._writers, global_step,
|
utils_tensorboard.write_results_to_tensorboard(
|
||||||
|
results, self._writers, global_step,
|
||||||
self._tensorboard_merge_classifiers)
|
self._tensorboard_merge_classifiers)
|
||||||
|
|
||||||
|
|
||||||
|
@ -132,8 +129,9 @@ def run_attack_on_tf_estimator_model(
|
||||||
in_train,
|
in_train,
|
||||||
out_train,
|
out_train,
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
slicing_spec: SlicingSpec = None,
|
slicing_spec: data_structures.SlicingSpec = None,
|
||||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
attack_types: Iterable[data_structures.AttackType] = (
|
||||||
|
data_structures.AttackType.THRESHOLD_ATTACK,)):
|
||||||
"""Performs the attack in the end of training.
|
"""Performs the attack in the end of training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -164,14 +162,14 @@ def run_attack_on_tf_estimator_model(
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def run_attack_helper(
|
def run_attack_helper(estimator,
|
||||||
estimator,
|
|
||||||
in_train_input_fn,
|
in_train_input_fn,
|
||||||
out_train_input_fn,
|
out_train_input_fn,
|
||||||
in_train_labels,
|
in_train_labels,
|
||||||
out_train_labels,
|
out_train_labels,
|
||||||
slicing_spec: SlicingSpec = None,
|
slicing_spec: data_structures.SlicingSpec = None,
|
||||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
attack_types: Iterable[data_structures.AttackType] = (
|
||||||
|
data_structures.AttackType.THRESHOLD_ATTACK,)):
|
||||||
"""A helper function to perform attack.
|
"""A helper function to perform attack.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -192,7 +190,7 @@ def run_attack_helper(
|
||||||
out_train_pred, out_train_loss = calculate_losses(estimator,
|
out_train_pred, out_train_loss = calculate_losses(estimator,
|
||||||
out_train_input_fn,
|
out_train_input_fn,
|
||||||
out_train_labels)
|
out_train_labels)
|
||||||
attack_input = AttackInputData(
|
attack_input = data_structures.AttackInputData(
|
||||||
logits_train=in_train_pred,
|
logits_train=in_train_pred,
|
||||||
logits_test=out_train_pred,
|
logits_test=out_train_pred,
|
||||||
labels_train=in_train_labels,
|
labels_train=in_train_labels,
|
||||||
|
|
|
@ -13,42 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utility functions for writing attack results to tensorboard."""
|
"""Utility functions for writing attack results to tensorboard."""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import tensorflow as tf2
|
import tensorflow as tf
|
||||||
import tensorflow.compat.v1 as tf1
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||||
|
|
||||||
|
|
||||||
def write_to_tensorboard(writers, tags, values, step):
|
|
||||||
"""Write metrics to tensorboard.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
writers: a list of tensorboard writers or one writer to be used for metrics.
|
|
||||||
If it's a list, it should be of the same length as tags
|
|
||||||
tags: a list of tags of metrics
|
|
||||||
values: a list of values of metrics with the same length as tags
|
|
||||||
step: step for the tensorboard summary
|
|
||||||
"""
|
|
||||||
if writers is None or not writers:
|
|
||||||
raise ValueError('write_to_tensorboard does not get any writer.')
|
|
||||||
|
|
||||||
if not isinstance(writers, list):
|
|
||||||
writers = [writers] * len(tags)
|
|
||||||
|
|
||||||
assert len(writers) == len(tags) == len(values)
|
|
||||||
|
|
||||||
for writer, tag, val in zip(writers, tags, values):
|
|
||||||
summary = tf1.Summary()
|
|
||||||
summary.value.add(tag=tag, simple_value=val)
|
|
||||||
writer.add_summary(summary, step)
|
|
||||||
|
|
||||||
for writer in set(writers):
|
|
||||||
writer.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def write_to_tensorboard_tf2(writers, tags, values, step):
|
def write_to_tensorboard_tf2(writers, tags, values, step):
|
||||||
"""Write metrics to tensorboard.
|
"""Write metrics to tensorboard.
|
||||||
|
|
||||||
|
@ -69,7 +40,7 @@ def write_to_tensorboard_tf2(writers, tags, values, step):
|
||||||
|
|
||||||
for writer, tag, val in zip(writers, tags, values):
|
for writer, tag, val in zip(writers, tags, values):
|
||||||
with writer.as_default():
|
with writer.as_default():
|
||||||
tf2.summary.scalar(tag, val, step=step)
|
tf.summary.scalar(tag, val, step=step)
|
||||||
writer.flush()
|
writer.flush()
|
||||||
|
|
||||||
for writer in set(writers):
|
for writer in set(writers):
|
||||||
|
@ -77,39 +48,9 @@ def write_to_tensorboard_tf2(writers, tags, values, step):
|
||||||
writer.flush()
|
writer.flush()
|
||||||
|
|
||||||
|
|
||||||
def write_results_to_tensorboard(attack_results: AttackResults,
|
|
||||||
writers: Union[tf1.summary.FileWriter,
|
|
||||||
List[tf1.summary.FileWriter]],
|
|
||||||
step: int, merge_classifiers: bool):
|
|
||||||
"""Write attack results to tensorboard.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attack_results: results from attack
|
|
||||||
writers: a list of tensorboard writers or one writer to be used for metrics
|
|
||||||
step: step for the tensorboard summary
|
|
||||||
merge_classifiers: if true, plot different classifiers with the same
|
|
||||||
slicing_spec and metric in the same figure
|
|
||||||
"""
|
|
||||||
if writers is None or not writers:
|
|
||||||
raise ValueError('write_results_to_tensorboard does not get any writer.')
|
|
||||||
|
|
||||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
|
||||||
attack_results)
|
|
||||||
if merge_classifiers:
|
|
||||||
att_tags = ['attack/' + f'{s}_{m}' for s, m in zip(att_slices, att_metrics)]
|
|
||||||
write_to_tensorboard([writers[t] for t in att_types], att_tags, att_values,
|
|
||||||
step)
|
|
||||||
else:
|
|
||||||
att_tags = [
|
|
||||||
'attack/' + f'{s}_{t}_{m}'
|
|
||||||
for t, s, m in zip(att_types, att_slices, att_metrics)
|
|
||||||
]
|
|
||||||
write_to_tensorboard(writers, att_tags, att_values, step)
|
|
||||||
|
|
||||||
|
|
||||||
def write_results_to_tensorboard_tf2(
|
def write_results_to_tensorboard_tf2(
|
||||||
attack_results: AttackResults,
|
attack_results: AttackResults,
|
||||||
writers: Union[tf2.summary.SummaryWriter, List[tf2.summary.SummaryWriter]],
|
writers: Union[tf.summary.SummaryWriter, List[tf.summary.SummaryWriter]],
|
||||||
step: int, merge_classifiers: bool):
|
step: int, merge_classifiers: bool):
|
||||||
"""Write attack results to tensorboard.
|
"""Write attack results to tensorboard.
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue