Update tf_estimator_evaluation and keras_evaluation to new API.

PiperOrigin-RevId: 328195220
This commit is contained in:
Shuang Song 2020-08-24 13:03:02 -07:00 committed by A. Unique TensorFlower
parent 7a77d5d92c
commit f90c78bd54
8 changed files with 194 additions and 99 deletions

View file

@ -293,8 +293,8 @@ class RocCurve:
"""Returns AUC and advantage metrics.""" """Returns AUC and advantage metrics."""
return '\n'.join([ return '\n'.join([
'RocCurve(', 'RocCurve(',
' AUC: %f.02' % self.get_auc(), ' AUC: %.2f' % self.get_auc(),
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')' ' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
]) ])
@ -324,8 +324,8 @@ class SingleAttackResult:
'SingleAttackResult(', 'SingleAttackResult(',
' SliceSpec: %s' % str(self.slice_spec), ' SliceSpec: %s' % str(self.slice_spec),
' AttackType: %s' % str(self.attack_type), ' AttackType: %s' % str(self.attack_type),
' AUC: %f.02' % self.get_auc(), ' AUC: %.2f' % self.get_auc(),
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')' ' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
]) ])

View file

@ -15,11 +15,17 @@
# Lint as: python3 # Lint as: python3
"""A callback and a function in keras for membership inference attack.""" """A callback and a function in keras for membership inference attack."""
from typing import Iterable
from absl import logging from absl import logging
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
@ -44,20 +50,25 @@ def calculate_losses(model, data, labels):
class MembershipInferenceCallback(tf.keras.callbacks.Callback): class MembershipInferenceCallback(tf.keras.callbacks.Callback):
"""Callback to perform membership inference attack on epoch end.""" """Callback to perform membership inference attack on epoch end."""
def __init__(self, in_train, out_train, attack_classifiers, def __init__(
self,
in_train, out_train,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
tensorboard_dir=None): tensorboard_dir=None):
"""Initalizes the callback. """Initalizes the callback.
Args: Args:
in_train: (in_training samples, in_training labels) in_train: (in_training samples, in_training labels)
out_train: (out_training samples, out_training labels) out_train: (out_training samples, out_training labels)
attack_classifiers: a list of classifiers to be used by attacker, must be slicing_spec: slicing specification of the attack
a subset of ['lr', 'mlp', 'rf', 'knn'] attack_types: a list of attacks, each of type AttackType
tensorboard_dir: directory for tensorboard summary tensorboard_dir: directory for tensorboard summary
""" """
self._in_train_data, self._in_train_labels = in_train self._in_train_data, self._in_train_labels = in_train
self._out_train_data, self._out_train_labels = out_train self._out_train_data, self._out_train_labels = out_train
self._attack_classifiers = attack_classifiers self._slicing_spec = slicing_spec
self._attack_types = attack_types
# Setup tensorboard writer if tensorboard_dir is specified # Setup tensorboard writer if tensorboard_dir is specified
if tensorboard_dir: if tensorboard_dir:
with tf.Graph().as_default(): with tf.Graph().as_default():
@ -71,24 +82,33 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
self.model, self.model,
(self._in_train_data, self._in_train_labels), (self._in_train_data, self._in_train_labels),
(self._out_train_data, self._out_train_labels), (self._out_train_data, self._out_train_labels),
self._attack_classifiers) self._slicing_spec,
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) self._attack_types)
logging.info(results) logging.info(results)
attack_properties, attack_values = get_all_attack_results(results)
print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))
# Write to tensorboard if tensorboard_dir is specified # Write to tensorboard if tensorboard_dir is specified
write_to_tensorboard(self._writer, ['attack advantage'], attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
[results['all_thresh_loss_advantage']], epoch) write_to_tensorboard(self._writer, attack_property_tags, attack_values,
epoch)
def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers): def run_attack_on_keras_model(
model, in_train, out_train,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
"""Performs the attack on a trained model. """Performs the attack on a trained model.
Args: Args:
model: model to be tested model: model to be tested
in_train: a (in_training samples, in_training labels) tuple in_train: a (in_training samples, in_training labels) tuple
out_train: a (out_training samples, out_training labels) tuple out_train: a (out_training samples, out_training labels) tuple
attack_classifiers: a list of classifiers to be used by attacker, must be slicing_spec: slicing specification of the attack
a subset of ['lr', 'mlp', 'rf', 'knn'] attack_types: a list of attacks, each of type AttackType
Returns: Returns:
Results of the attack Results of the attack
""" """
@ -100,9 +120,12 @@ def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
in_train_labels) in_train_labels)
out_train_pred, out_train_loss = calculate_losses(model, out_train_data, out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
out_train_labels) out_train_labels)
results = mia.run_all_attacks(in_train_loss, out_train_loss, attack_input = AttackInputData(
in_train_pred, out_train_pred, logits_train=in_train_pred, logits_test=out_train_pred,
in_train_labels, out_train_labels, labels_train=in_train_labels, labels_test=out_train_labels,
attack_classifiers=attack_classifiers) loss_train=in_train_loss, loss_test=out_train_loss
)
results = mia.run_attacks(attack_input,
slicing_spec=slicing_spec,
attack_types=attack_types)
return results return results

View file

@ -20,8 +20,12 @@ from absl import flags
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer
@ -78,10 +82,11 @@ def main(unused_argv):
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
# Get callback for membership inference attack. # Get callback for membership inference attack.
mia_callback = MembershipInferenceCallback((train_data, train_labels), mia_callback = MembershipInferenceCallback(
(train_data, train_labels),
(test_data, test_labels), (test_data, test_labels),
[], attack_types=[AttackType.THRESHOLD_ATTACK],
FLAGS.model_dir) tensorboard_dir=FLAGS.model_dir)
# Train model with Keras # Train model with Keras
model.fit(train_data, train_labels, model.fit(train_data, train_labels,
@ -91,13 +96,18 @@ def main(unused_argv):
callbacks=[mia_callback], callbacks=[mia_callback],
verbose=2) verbose=2)
print('End of training attack') print('End of training attack:')
attack_results = run_attack_on_keras_model(model, attack_results = run_attack_on_keras_model(
model,
(train_data, train_labels), (train_data, train_labels),
(test_data, test_labels), (test_data, test_labels),
[]) slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
print('all_thresh_loss_advantage', attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
attack_results['all_thresh_loss_advantage']) )
attack_properties, attack_values = get_all_attack_results(attack_results)
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -21,6 +21,9 @@ import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
class UtilsTest(absltest.TestCase): class UtilsTest(absltest.TestCase):
@ -62,10 +65,11 @@ class UtilsTest(absltest.TestCase):
self.model, self.model,
(self.train_data, self.train_labels), (self.train_data, self.train_labels),
(self.test_data, self.test_labels), (self.test_data, self.test_labels),
[]) attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, dict) self.assertIsInstance(results, AttackResults)
self.assertIn('all_thresh_loss_auc', results) attack_properties, attack_values = get_all_attack_results(results)
self.assertIn('all_thresh_loss_advantage', results) self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -15,13 +15,19 @@
# Lint as: python3 # Lint as: python3
"""A hook and a function in tf estimator for membership inference attack.""" """A hook and a function in tf estimator for membership inference attack."""
from typing import Iterable
from absl import logging from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
@ -49,16 +55,17 @@ def calculate_losses(estimator, input_fn, labels):
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
"""Training hook to perform membership inference attack after an epoch.""" """Training hook to perform membership inference attack on epoch end."""
def __init__(self, def __init__(
self,
estimator, estimator,
in_train, in_train, out_train,
out_train,
input_fn_constructor, input_fn_constructor,
attack_classifiers, slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
writer=None): writer=None):
"""Initalizes the hook. """Initialize the hook.
Args: Args:
estimator: model to be tested estimator: model to be tested
@ -66,8 +73,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
out_train: (out_training samples, out_training labels) out_train: (out_training samples, out_training labels)
input_fn_constructor: a function that receives sample, label and construct input_fn_constructor: a function that receives sample, label and construct
the input_fn for model prediction the input_fn for model prediction
attack_classifiers: a list of classifiers to be used by attacker, must be slicing_spec: slicing specification of the attack
a subset of ['lr', 'mlp', 'rf', 'knn'] attack_types: a list of attacks, each of type AttackType
writer: summary writer for tensorboard writer: summary writer for tensorboard
""" """
in_train_data, self._in_train_labels = in_train in_train_data, self._in_train_labels = in_train
@ -79,7 +86,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
self._out_train_input_fn = input_fn_constructor(out_train_data, self._out_train_input_fn = input_fn_constructor(out_train_data,
self._out_train_labels) self._out_train_labels)
self._estimator = estimator self._estimator = estimator
self._attack_classifiers = attack_classifiers self._slicing_spec = slicing_spec
self._attack_types = attack_types
self._writer = writer self._writer = writer
if self._writer: if self._writer:
logging.info('Will write to tensorboard.') logging.info('Will write to tensorboard.')
@ -89,19 +97,28 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
self._in_train_input_fn, self._in_train_input_fn,
self._out_train_input_fn, self._out_train_input_fn,
self._in_train_labels, self._out_train_labels, self._in_train_labels, self._out_train_labels,
self._attack_classifiers) self._slicing_spec,
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) self._attack_types)
logging.info(results) logging.info(results)
attack_properties, attack_values = get_all_attack_results(results)
print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))
# Write to tensorboard if writer is specified # Write to tensorboard if writer is specified
global_step = self._estimator.get_variable_value('global_step') global_step = self._estimator.get_variable_value('global_step')
write_to_tensorboard(self._writer, ['attack advantage'], attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
[results['all_thresh_loss_advantage']], global_step) write_to_tensorboard(self._writer, attack_property_tags, attack_values,
global_step)
def run_attack_on_tf_estimator_model(estimator, in_train, out_train, def run_attack_on_tf_estimator_model(
input_fn_constructor, attack_classifiers): estimator, in_train, out_train,
"""A function to perform the attack in the end of training. input_fn_constructor,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
"""Performs the attack in the end of training.
Args: Args:
estimator: model to be tested estimator: model to be tested
@ -109,8 +126,8 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
out_train: (out_training samples, out_training labels) out_train: (out_training samples, out_training labels)
input_fn_constructor: a function that receives sample, label and construct input_fn_constructor: a function that receives sample, label and construct
the input_fn for model prediction the input_fn for model prediction
attack_classifiers: a list of classifiers to be used by attacker, must be slicing_spec: slicing specification of the attack
a subset of ['lr', 'mlp', 'rf', 'knn'] attack_types: a list of attacks, each of type AttackType
Returns: Returns:
Results of the attack Results of the attack
""" """
@ -125,17 +142,19 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
results = run_attack_helper(estimator, results = run_attack_helper(estimator,
in_train_input_fn, out_train_input_fn, in_train_input_fn, out_train_input_fn,
in_train_labels, out_train_labels, in_train_labels, out_train_labels,
attack_classifiers) slicing_spec,
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) attack_types)
logging.info('End of training attack:') logging.info('End of training attack:')
logging.info(results) logging.info(results)
return results return results
def run_attack_helper(estimator, def run_attack_helper(
estimator,
in_train_input_fn, out_train_input_fn, in_train_input_fn, out_train_input_fn,
in_train_labels, out_train_labels, in_train_labels, out_train_labels,
attack_classifiers): slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
"""A helper function to perform attack. """A helper function to perform attack.
Args: Args:
@ -144,8 +163,8 @@ def run_attack_helper(estimator,
out_train_input_fn: input_fn for out of training data out_train_input_fn: input_fn for out of training data
in_train_labels: in training labels in_train_labels: in training labels
out_train_labels: out of training labels out_train_labels: out of training labels
attack_classifiers: a list of classifiers to be used by attacker, must be slicing_spec: slicing specification of the attack
a subset of ['lr', 'mlp', 'rf', 'knn'] attack_types: a list of attacks, each of type AttackType
Returns: Returns:
Results of the attack Results of the attack
""" """
@ -156,9 +175,13 @@ def run_attack_helper(estimator,
out_train_pred, out_train_loss = calculate_losses(estimator, out_train_pred, out_train_loss = calculate_losses(estimator,
out_train_input_fn, out_train_input_fn,
out_train_labels) out_train_labels)
results = mia.run_all_attacks(in_train_loss, out_train_loss, attack_input = AttackInputData(
in_train_pred, out_train_pred, logits_train=in_train_pred, logits_test=out_train_pred,
in_train_labels, out_train_labels, labels_train=in_train_labels, labels_test=out_train_labels,
attack_classifiers=attack_classifiers) loss_train=in_train_loss, loss_test=out_train_loss
)
results = mia.run_attacks(attack_input,
slicing_spec=slicing_spec,
attack_types=attack_types)
return results return results

View file

@ -21,9 +21,11 @@ from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer
@ -97,9 +99,9 @@ def load_mnist():
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.ERROR)
logging.set_verbosity(logging.INFO) logging.set_verbosity(logging.ERROR)
logging.set_stderrthreshold(logging.INFO) logging.set_stderrthreshold(logging.ERROR)
logging.get_absl_handler().use_absl_log_file() logging.get_absl_handler().use_absl_log_file()
# Load training and test data. # Load training and test data.
@ -121,12 +123,13 @@ def main(unused_argv):
summary_writer = tf.summary.FileWriter(FLAGS.model_dir) summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
else: else:
summary_writer = None summary_writer = None
mia_hook = MembershipInferenceTrainingHook(mnist_classifier, mia_hook = MembershipInferenceTrainingHook(
mnist_classifier,
(train_data, train_labels), (train_data, train_labels),
(test_data, test_labels), (test_data, test_labels),
input_fn_constructor, input_fn_constructor,
[], attack_types=[AttackType.THRESHOLD_ATTACK],
summary_writer) writer=summary_writer)
# Create tf.Estimator input functions for the training and test data. # Create tf.Estimator input functions for the training and test data.
train_input_fn = tf.estimator.inputs.numpy_input_fn( train_input_fn = tf.estimator.inputs.numpy_input_fn(
@ -151,11 +154,17 @@ def main(unused_argv):
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy)) print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
print('End of training attack') print('End of training attack')
run_attack_on_tf_estimator_model(mnist_classifier, attack_results = run_attack_on_tf_estimator_model(
mnist_classifier,
(train_data, train_labels), (train_data, train_labels),
(test_data, test_labels), (test_data, test_labels),
input_fn_constructor, input_fn_constructor,
['lr']) slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
)
attack_properties, attack_values = get_all_attack_results(attack_results)
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -21,6 +21,9 @@ import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
class UtilsTest(absltest.TestCase): class UtilsTest(absltest.TestCase):
@ -77,15 +80,17 @@ class UtilsTest(absltest.TestCase):
def test_run_attack_helper(self): def test_run_attack_helper(self):
"""Test the attack.""" """Test the attack."""
results = tf_estimator_evaluation.run_attack_helper(self.classifier, results = tf_estimator_evaluation.run_attack_helper(
self.classifier,
self.input_fn_train, self.input_fn_train,
self.input_fn_test, self.input_fn_test,
self.train_labels, self.train_labels,
self.test_labels, self.test_labels,
[]) attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, dict) self.assertIsInstance(results, AttackResults)
self.assertIn('all_thresh_loss_auc', results) attack_properties, attack_values = get_all_attack_results(results)
self.assertIn('all_thresh_loss_advantage', results) self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2)
def test_run_attack_on_tf_estimator_model(self): def test_run_attack_on_tf_estimator_model(self):
"""Test the attack on the final models.""" """Test the attack on the final models."""
@ -97,10 +102,11 @@ class UtilsTest(absltest.TestCase):
(self.train_data, self.train_labels), (self.train_data, self.train_labels),
(self.test_data, self.test_labels), (self.test_data, self.test_labels),
input_fn_constructor, input_fn_constructor,
[]) attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, dict) self.assertIsInstance(results, AttackResults)
self.assertIn('all_thresh_loss_auc', results) attack_properties, attack_values = get_all_attack_results(results)
self.assertIn('all_thresh_loss_advantage', results) self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -20,6 +20,7 @@ from typing import Text, Dict, Union, List, Any, Tuple
import numpy as np import numpy as np
from sklearn import metrics from sklearn import metrics
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
ArrayDict = Dict[Text, np.ndarray] ArrayDict = Dict[Text, np.ndarray]
@ -73,6 +74,25 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
return {prefix + k: v for k, v in in_dict.items()} return {prefix + k: v for k, v in in_dict.items()}
# ------------------------------------------------------------------------------
# Utilities for managing result.
# ------------------------------------------------------------------------------
def get_all_attack_results(results: AttackResults):
"""Get all results as a list of attack properties and a list of attack result."""
properties = []
values = []
for attack_result in results.single_attack_results:
slice_spec = attack_result.slice_spec
prop = [str(slice_spec), str(attack_result.attack_type)]
properties += [prop + ['adv'], prop + ['auc']]
values += [float(attack_result.get_attacker_advantage()),
float(attack_result.get_auc())]
return properties, values
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Subsampling and data selection functionality # Subsampling and data selection functionality
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------