forked from 626_privacy/tensorflow_privacy
Update tf_estimator_evaluation and keras_evaluation to new API.
PiperOrigin-RevId: 328195220
This commit is contained in:
parent
7a77d5d92c
commit
f90c78bd54
8 changed files with 194 additions and 99 deletions
|
@ -293,8 +293,8 @@ class RocCurve:
|
|||
"""Returns AUC and advantage metrics."""
|
||||
return '\n'.join([
|
||||
'RocCurve(',
|
||||
' AUC: %f.02' % self.get_auc(),
|
||||
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
||||
' AUC: %.2f' % self.get_auc(),
|
||||
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
|
||||
])
|
||||
|
||||
|
||||
|
@ -324,8 +324,8 @@ class SingleAttackResult:
|
|||
'SingleAttackResult(',
|
||||
' SliceSpec: %s' % str(self.slice_spec),
|
||||
' AttackType: %s' % str(self.attack_type),
|
||||
' AUC: %f.02' % self.get_auc(),
|
||||
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
||||
' AUC: %.2f' % self.get_auc(),
|
||||
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
|
||||
])
|
||||
|
||||
|
||||
|
|
|
@ -15,11 +15,17 @@
|
|||
# Lint as: python3
|
||||
"""A callback and a function in keras for membership inference attack."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from absl import logging
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||
|
||||
|
@ -44,20 +50,25 @@ def calculate_losses(model, data, labels):
|
|||
class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||
"""Callback to perform membership inference attack on epoch end."""
|
||||
|
||||
def __init__(self, in_train, out_train, attack_classifiers,
|
||||
tensorboard_dir=None):
|
||||
def __init__(
|
||||
self,
|
||||
in_train, out_train,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||
tensorboard_dir=None):
|
||||
"""Initalizes the callback.
|
||||
|
||||
Args:
|
||||
in_train: (in_training samples, in_training labels)
|
||||
out_train: (out_training samples, out_training labels)
|
||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
||||
slicing_spec: slicing specification of the attack
|
||||
attack_types: a list of attacks, each of type AttackType
|
||||
tensorboard_dir: directory for tensorboard summary
|
||||
"""
|
||||
self._in_train_data, self._in_train_labels = in_train
|
||||
self._out_train_data, self._out_train_labels = out_train
|
||||
self._attack_classifiers = attack_classifiers
|
||||
self._slicing_spec = slicing_spec
|
||||
self._attack_types = attack_types
|
||||
# Setup tensorboard writer if tensorboard_dir is specified
|
||||
if tensorboard_dir:
|
||||
with tf.Graph().as_default():
|
||||
|
@ -71,24 +82,33 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
|||
self.model,
|
||||
(self._in_train_data, self._in_train_labels),
|
||||
(self._out_train_data, self._out_train_labels),
|
||||
self._attack_classifiers)
|
||||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
||||
self._slicing_spec,
|
||||
self._attack_types)
|
||||
logging.info(results)
|
||||
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
print('Attack result:')
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
|
||||
# Write to tensorboard if tensorboard_dir is specified
|
||||
write_to_tensorboard(self._writer, ['attack advantage'],
|
||||
[results['all_thresh_loss_advantage']], epoch)
|
||||
attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
|
||||
write_to_tensorboard(self._writer, attack_property_tags, attack_values,
|
||||
epoch)
|
||||
|
||||
|
||||
def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
|
||||
def run_attack_on_keras_model(
|
||||
model, in_train, out_train,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
||||
"""Performs the attack on a trained model.
|
||||
|
||||
Args:
|
||||
model: model to be tested
|
||||
in_train: a (in_training samples, in_training labels) tuple
|
||||
out_train: a (out_training samples, out_training labels) tuple
|
||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
||||
slicing_spec: slicing specification of the attack
|
||||
attack_types: a list of attacks, each of type AttackType
|
||||
Returns:
|
||||
Results of the attack
|
||||
"""
|
||||
|
@ -100,9 +120,12 @@ def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
|
|||
in_train_labels)
|
||||
out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
|
||||
out_train_labels)
|
||||
results = mia.run_all_attacks(in_train_loss, out_train_loss,
|
||||
in_train_pred, out_train_pred,
|
||||
in_train_labels, out_train_labels,
|
||||
attack_classifiers=attack_classifiers)
|
||||
attack_input = AttackInputData(
|
||||
logits_train=in_train_pred, logits_test=out_train_pred,
|
||||
labels_train=in_train_labels, labels_test=out_train_labels,
|
||||
loss_train=in_train_loss, loss_test=out_train_loss
|
||||
)
|
||||
results = mia.run_attacks(attack_input,
|
||||
slicing_spec=slicing_spec,
|
||||
attack_types=attack_types)
|
||||
return results
|
||||
|
||||
|
|
|
@ -20,8 +20,12 @@ from absl import flags
|
|||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
|
||||
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
|
||||
|
@ -78,10 +82,11 @@ def main(unused_argv):
|
|||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||
|
||||
# Get callback for membership inference attack.
|
||||
mia_callback = MembershipInferenceCallback((train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
[],
|
||||
FLAGS.model_dir)
|
||||
mia_callback = MembershipInferenceCallback(
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK],
|
||||
tensorboard_dir=FLAGS.model_dir)
|
||||
|
||||
# Train model with Keras
|
||||
model.fit(train_data, train_labels,
|
||||
|
@ -91,13 +96,18 @@ def main(unused_argv):
|
|||
callbacks=[mia_callback],
|
||||
verbose=2)
|
||||
|
||||
print('End of training attack')
|
||||
attack_results = run_attack_on_keras_model(model,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
[])
|
||||
print('all_thresh_loss_advantage',
|
||||
attack_results['all_thresh_loss_advantage'])
|
||||
print('End of training attack:')
|
||||
attack_results = run_attack_on_keras_model(
|
||||
model,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
||||
)
|
||||
|
||||
attack_properties, attack_values = get_all_attack_results(attack_results)
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -21,6 +21,9 @@ import numpy as np
|
|||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
|
||||
|
@ -62,10 +65,11 @@ class UtilsTest(absltest.TestCase):
|
|||
self.model,
|
||||
(self.train_data, self.train_labels),
|
||||
(self.test_data, self.test_labels),
|
||||
[])
|
||||
self.assertIsInstance(results, dict)
|
||||
self.assertIn('all_thresh_loss_auc', results)
|
||||
self.assertIn('all_thresh_loss_advantage', results)
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, AttackResults)
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
self.assertLen(attack_properties, 2)
|
||||
self.assertLen(attack_values, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -15,13 +15,19 @@
|
|||
# Lint as: python3
|
||||
"""A hook and a function in tf estimator for membership inference attack."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||
|
||||
|
@ -49,16 +55,17 @@ def calculate_losses(estimator, input_fn, labels):
|
|||
|
||||
|
||||
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||
"""Training hook to perform membership inference attack after an epoch."""
|
||||
"""Training hook to perform membership inference attack on epoch end."""
|
||||
|
||||
def __init__(self,
|
||||
estimator,
|
||||
in_train,
|
||||
out_train,
|
||||
input_fn_constructor,
|
||||
attack_classifiers,
|
||||
writer=None):
|
||||
"""Initalizes the hook.
|
||||
def __init__(
|
||||
self,
|
||||
estimator,
|
||||
in_train, out_train,
|
||||
input_fn_constructor,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||
writer=None):
|
||||
"""Initialize the hook.
|
||||
|
||||
Args:
|
||||
estimator: model to be tested
|
||||
|
@ -66,8 +73,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
|||
out_train: (out_training samples, out_training labels)
|
||||
input_fn_constructor: a function that receives sample, label and construct
|
||||
the input_fn for model prediction
|
||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
||||
slicing_spec: slicing specification of the attack
|
||||
attack_types: a list of attacks, each of type AttackType
|
||||
writer: summary writer for tensorboard
|
||||
"""
|
||||
in_train_data, self._in_train_labels = in_train
|
||||
|
@ -79,7 +86,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
|||
self._out_train_input_fn = input_fn_constructor(out_train_data,
|
||||
self._out_train_labels)
|
||||
self._estimator = estimator
|
||||
self._attack_classifiers = attack_classifiers
|
||||
self._slicing_spec = slicing_spec
|
||||
self._attack_types = attack_types
|
||||
self._writer = writer
|
||||
if self._writer:
|
||||
logging.info('Will write to tensorboard.')
|
||||
|
@ -89,19 +97,28 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
|||
self._in_train_input_fn,
|
||||
self._out_train_input_fn,
|
||||
self._in_train_labels, self._out_train_labels,
|
||||
self._attack_classifiers)
|
||||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
||||
self._slicing_spec,
|
||||
self._attack_types)
|
||||
logging.info(results)
|
||||
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
print('Attack result:')
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
|
||||
# Write to tensorboard if writer is specified
|
||||
global_step = self._estimator.get_variable_value('global_step')
|
||||
write_to_tensorboard(self._writer, ['attack advantage'],
|
||||
[results['all_thresh_loss_advantage']], global_step)
|
||||
attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
|
||||
write_to_tensorboard(self._writer, attack_property_tags, attack_values,
|
||||
global_step)
|
||||
|
||||
|
||||
def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
||||
input_fn_constructor, attack_classifiers):
|
||||
"""A function to perform the attack in the end of training.
|
||||
def run_attack_on_tf_estimator_model(
|
||||
estimator, in_train, out_train,
|
||||
input_fn_constructor,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
||||
"""Performs the attack in the end of training.
|
||||
|
||||
Args:
|
||||
estimator: model to be tested
|
||||
|
@ -109,8 +126,8 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
|||
out_train: (out_training samples, out_training labels)
|
||||
input_fn_constructor: a function that receives sample, label and construct
|
||||
the input_fn for model prediction
|
||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
||||
slicing_spec: slicing specification of the attack
|
||||
attack_types: a list of attacks, each of type AttackType
|
||||
Returns:
|
||||
Results of the attack
|
||||
"""
|
||||
|
@ -125,17 +142,19 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
|||
results = run_attack_helper(estimator,
|
||||
in_train_input_fn, out_train_input_fn,
|
||||
in_train_labels, out_train_labels,
|
||||
attack_classifiers)
|
||||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
||||
slicing_spec,
|
||||
attack_types)
|
||||
logging.info('End of training attack:')
|
||||
logging.info(results)
|
||||
return results
|
||||
|
||||
|
||||
def run_attack_helper(estimator,
|
||||
in_train_input_fn, out_train_input_fn,
|
||||
in_train_labels, out_train_labels,
|
||||
attack_classifiers):
|
||||
def run_attack_helper(
|
||||
estimator,
|
||||
in_train_input_fn, out_train_input_fn,
|
||||
in_train_labels, out_train_labels,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
||||
"""A helper function to perform attack.
|
||||
|
||||
Args:
|
||||
|
@ -144,8 +163,8 @@ def run_attack_helper(estimator,
|
|||
out_train_input_fn: input_fn for out of training data
|
||||
in_train_labels: in training labels
|
||||
out_train_labels: out of training labels
|
||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
||||
slicing_spec: slicing specification of the attack
|
||||
attack_types: a list of attacks, each of type AttackType
|
||||
Returns:
|
||||
Results of the attack
|
||||
"""
|
||||
|
@ -156,9 +175,13 @@ def run_attack_helper(estimator,
|
|||
out_train_pred, out_train_loss = calculate_losses(estimator,
|
||||
out_train_input_fn,
|
||||
out_train_labels)
|
||||
results = mia.run_all_attacks(in_train_loss, out_train_loss,
|
||||
in_train_pred, out_train_pred,
|
||||
in_train_labels, out_train_labels,
|
||||
attack_classifiers=attack_classifiers)
|
||||
attack_input = AttackInputData(
|
||||
logits_train=in_train_pred, logits_test=out_train_pred,
|
||||
labels_train=in_train_labels, labels_test=out_train_labels,
|
||||
loss_train=in_train_loss, loss_test=out_train_loss
|
||||
)
|
||||
results = mia.run_attacks(attack_input,
|
||||
slicing_spec=slicing_spec,
|
||||
attack_types=attack_types)
|
||||
return results
|
||||
|
||||
|
|
|
@ -21,9 +21,11 @@ from absl import logging
|
|||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
|
||||
|
@ -97,9 +99,9 @@ def load_mnist():
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
logging.set_verbosity(logging.INFO)
|
||||
logging.set_stderrthreshold(logging.INFO)
|
||||
tf.logging.set_verbosity(tf.logging.ERROR)
|
||||
logging.set_verbosity(logging.ERROR)
|
||||
logging.set_stderrthreshold(logging.ERROR)
|
||||
logging.get_absl_handler().use_absl_log_file()
|
||||
|
||||
# Load training and test data.
|
||||
|
@ -121,12 +123,13 @@ def main(unused_argv):
|
|||
summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
|
||||
else:
|
||||
summary_writer = None
|
||||
mia_hook = MembershipInferenceTrainingHook(mnist_classifier,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
input_fn_constructor,
|
||||
[],
|
||||
summary_writer)
|
||||
mia_hook = MembershipInferenceTrainingHook(
|
||||
mnist_classifier,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
input_fn_constructor,
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK],
|
||||
writer=summary_writer)
|
||||
|
||||
# Create tf.Estimator input functions for the training and test data.
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
|
@ -151,11 +154,17 @@ def main(unused_argv):
|
|||
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
||||
|
||||
print('End of training attack')
|
||||
run_attack_on_tf_estimator_model(mnist_classifier,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
input_fn_constructor,
|
||||
['lr'])
|
||||
attack_results = run_attack_on_tf_estimator_model(
|
||||
mnist_classifier,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
input_fn_constructor,
|
||||
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
||||
)
|
||||
attack_properties, attack_values = get_all_attack_results(attack_results)
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -21,6 +21,9 @@ import numpy as np
|
|||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
|
||||
|
@ -77,15 +80,17 @@ class UtilsTest(absltest.TestCase):
|
|||
|
||||
def test_run_attack_helper(self):
|
||||
"""Test the attack."""
|
||||
results = tf_estimator_evaluation.run_attack_helper(self.classifier,
|
||||
self.input_fn_train,
|
||||
self.input_fn_test,
|
||||
self.train_labels,
|
||||
self.test_labels,
|
||||
[])
|
||||
self.assertIsInstance(results, dict)
|
||||
self.assertIn('all_thresh_loss_auc', results)
|
||||
self.assertIn('all_thresh_loss_advantage', results)
|
||||
results = tf_estimator_evaluation.run_attack_helper(
|
||||
self.classifier,
|
||||
self.input_fn_train,
|
||||
self.input_fn_test,
|
||||
self.train_labels,
|
||||
self.test_labels,
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, AttackResults)
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
self.assertLen(attack_properties, 2)
|
||||
self.assertLen(attack_values, 2)
|
||||
|
||||
def test_run_attack_on_tf_estimator_model(self):
|
||||
"""Test the attack on the final models."""
|
||||
|
@ -97,10 +102,11 @@ class UtilsTest(absltest.TestCase):
|
|||
(self.train_data, self.train_labels),
|
||||
(self.test_data, self.test_labels),
|
||||
input_fn_constructor,
|
||||
[])
|
||||
self.assertIsInstance(results, dict)
|
||||
self.assertIn('all_thresh_loss_auc', results)
|
||||
self.assertIn('all_thresh_loss_advantage', results)
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, AttackResults)
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
self.assertLen(attack_properties, 2)
|
||||
self.assertLen(attack_values, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import Text, Dict, Union, List, Any, Tuple
|
|||
import numpy as np
|
||||
from sklearn import metrics
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
|
||||
|
||||
ArrayDict = Dict[Text, np.ndarray]
|
||||
|
@ -73,6 +74,25 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
|
|||
return {prefix + k: v for k, v in in_dict.items()}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Utilities for managing result.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_all_attack_results(results: AttackResults):
|
||||
"""Get all results as a list of attack properties and a list of attack result."""
|
||||
properties = []
|
||||
values = []
|
||||
for attack_result in results.single_attack_results:
|
||||
slice_spec = attack_result.slice_spec
|
||||
prop = [str(slice_spec), str(attack_result.attack_type)]
|
||||
properties += [prop + ['adv'], prop + ['auc']]
|
||||
values += [float(attack_result.get_attacker_advantage()),
|
||||
float(attack_result.get_auc())]
|
||||
|
||||
return properties, values
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Subsampling and data selection functionality
|
||||
# ------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in a new issue