Update tf_estimator_evaluation and keras_evaluation to new API.
PiperOrigin-RevId: 328195220
This commit is contained in:
parent
7a77d5d92c
commit
f90c78bd54
8 changed files with 194 additions and 99 deletions
|
@ -293,8 +293,8 @@ class RocCurve:
|
||||||
"""Returns AUC and advantage metrics."""
|
"""Returns AUC and advantage metrics."""
|
||||||
return '\n'.join([
|
return '\n'.join([
|
||||||
'RocCurve(',
|
'RocCurve(',
|
||||||
' AUC: %f.02' % self.get_auc(),
|
' AUC: %.2f' % self.get_auc(),
|
||||||
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@ -324,8 +324,8 @@ class SingleAttackResult:
|
||||||
'SingleAttackResult(',
|
'SingleAttackResult(',
|
||||||
' SliceSpec: %s' % str(self.slice_spec),
|
' SliceSpec: %s' % str(self.slice_spec),
|
||||||
' AttackType: %s' % str(self.attack_type),
|
' AttackType: %s' % str(self.attack_type),
|
||||||
' AUC: %f.02' % self.get_auc(),
|
' AUC: %.2f' % self.get_auc(),
|
||||||
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,11 +15,17 @@
|
||||||
# Lint as: python3
|
# Lint as: python3
|
||||||
"""A callback and a function in keras for membership inference attack."""
|
"""A callback and a function in keras for membership inference attack."""
|
||||||
|
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||||
|
|
||||||
|
@ -44,20 +50,25 @@ def calculate_losses(model, data, labels):
|
||||||
class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||||
"""Callback to perform membership inference attack on epoch end."""
|
"""Callback to perform membership inference attack on epoch end."""
|
||||||
|
|
||||||
def __init__(self, in_train, out_train, attack_classifiers,
|
def __init__(
|
||||||
tensorboard_dir=None):
|
self,
|
||||||
|
in_train, out_train,
|
||||||
|
slicing_spec: SlicingSpec = None,
|
||||||
|
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||||
|
tensorboard_dir=None):
|
||||||
"""Initalizes the callback.
|
"""Initalizes the callback.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_train: (in_training samples, in_training labels)
|
in_train: (in_training samples, in_training labels)
|
||||||
out_train: (out_training samples, out_training labels)
|
out_train: (out_training samples, out_training labels)
|
||||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
slicing_spec: slicing specification of the attack
|
||||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
attack_types: a list of attacks, each of type AttackType
|
||||||
tensorboard_dir: directory for tensorboard summary
|
tensorboard_dir: directory for tensorboard summary
|
||||||
"""
|
"""
|
||||||
self._in_train_data, self._in_train_labels = in_train
|
self._in_train_data, self._in_train_labels = in_train
|
||||||
self._out_train_data, self._out_train_labels = out_train
|
self._out_train_data, self._out_train_labels = out_train
|
||||||
self._attack_classifiers = attack_classifiers
|
self._slicing_spec = slicing_spec
|
||||||
|
self._attack_types = attack_types
|
||||||
# Setup tensorboard writer if tensorboard_dir is specified
|
# Setup tensorboard writer if tensorboard_dir is specified
|
||||||
if tensorboard_dir:
|
if tensorboard_dir:
|
||||||
with tf.Graph().as_default():
|
with tf.Graph().as_default():
|
||||||
|
@ -71,24 +82,33 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||||
self.model,
|
self.model,
|
||||||
(self._in_train_data, self._in_train_labels),
|
(self._in_train_data, self._in_train_labels),
|
||||||
(self._out_train_data, self._out_train_labels),
|
(self._out_train_data, self._out_train_labels),
|
||||||
self._attack_classifiers)
|
self._slicing_spec,
|
||||||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
self._attack_types)
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
|
|
||||||
|
attack_properties, attack_values = get_all_attack_results(results)
|
||||||
|
print('Attack result:')
|
||||||
|
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||||
|
zip(attack_properties, attack_values)]))
|
||||||
|
|
||||||
# Write to tensorboard if tensorboard_dir is specified
|
# Write to tensorboard if tensorboard_dir is specified
|
||||||
write_to_tensorboard(self._writer, ['attack advantage'],
|
attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
|
||||||
[results['all_thresh_loss_advantage']], epoch)
|
write_to_tensorboard(self._writer, attack_property_tags, attack_values,
|
||||||
|
epoch)
|
||||||
|
|
||||||
|
|
||||||
def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
|
def run_attack_on_keras_model(
|
||||||
|
model, in_train, out_train,
|
||||||
|
slicing_spec: SlicingSpec = None,
|
||||||
|
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
||||||
"""Performs the attack on a trained model.
|
"""Performs the attack on a trained model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: model to be tested
|
model: model to be tested
|
||||||
in_train: a (in_training samples, in_training labels) tuple
|
in_train: a (in_training samples, in_training labels) tuple
|
||||||
out_train: a (out_training samples, out_training labels) tuple
|
out_train: a (out_training samples, out_training labels) tuple
|
||||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
slicing_spec: slicing specification of the attack
|
||||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
attack_types: a list of attacks, each of type AttackType
|
||||||
Returns:
|
Returns:
|
||||||
Results of the attack
|
Results of the attack
|
||||||
"""
|
"""
|
||||||
|
@ -100,9 +120,12 @@ def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
|
||||||
in_train_labels)
|
in_train_labels)
|
||||||
out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
|
out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
|
||||||
out_train_labels)
|
out_train_labels)
|
||||||
results = mia.run_all_attacks(in_train_loss, out_train_loss,
|
attack_input = AttackInputData(
|
||||||
in_train_pred, out_train_pred,
|
logits_train=in_train_pred, logits_test=out_train_pred,
|
||||||
in_train_labels, out_train_labels,
|
labels_train=in_train_labels, labels_test=out_train_labels,
|
||||||
attack_classifiers=attack_classifiers)
|
loss_train=in_train_loss, loss_test=out_train_loss
|
||||||
|
)
|
||||||
|
results = mia.run_attacks(attack_input,
|
||||||
|
slicing_spec=slicing_spec,
|
||||||
|
attack_types=attack_types)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,12 @@ from absl import flags
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
|
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
|
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||||
|
|
||||||
|
|
||||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
|
|
||||||
|
@ -78,10 +82,11 @@ def main(unused_argv):
|
||||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||||
|
|
||||||
# Get callback for membership inference attack.
|
# Get callback for membership inference attack.
|
||||||
mia_callback = MembershipInferenceCallback((train_data, train_labels),
|
mia_callback = MembershipInferenceCallback(
|
||||||
(test_data, test_labels),
|
(train_data, train_labels),
|
||||||
[],
|
(test_data, test_labels),
|
||||||
FLAGS.model_dir)
|
attack_types=[AttackType.THRESHOLD_ATTACK],
|
||||||
|
tensorboard_dir=FLAGS.model_dir)
|
||||||
|
|
||||||
# Train model with Keras
|
# Train model with Keras
|
||||||
model.fit(train_data, train_labels,
|
model.fit(train_data, train_labels,
|
||||||
|
@ -91,13 +96,18 @@ def main(unused_argv):
|
||||||
callbacks=[mia_callback],
|
callbacks=[mia_callback],
|
||||||
verbose=2)
|
verbose=2)
|
||||||
|
|
||||||
print('End of training attack')
|
print('End of training attack:')
|
||||||
attack_results = run_attack_on_keras_model(model,
|
attack_results = run_attack_on_keras_model(
|
||||||
(train_data, train_labels),
|
model,
|
||||||
(test_data, test_labels),
|
(train_data, train_labels),
|
||||||
[])
|
(test_data, test_labels),
|
||||||
print('all_thresh_loss_advantage',
|
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
attack_results['all_thresh_loss_advantage'])
|
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
||||||
|
)
|
||||||
|
|
||||||
|
attack_properties, attack_values = get_all_attack_results(attack_results)
|
||||||
|
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||||
|
zip(attack_properties, attack_values)]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -21,6 +21,9 @@ import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
|
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||||
|
|
||||||
|
|
||||||
class UtilsTest(absltest.TestCase):
|
class UtilsTest(absltest.TestCase):
|
||||||
|
@ -62,10 +65,11 @@ class UtilsTest(absltest.TestCase):
|
||||||
self.model,
|
self.model,
|
||||||
(self.train_data, self.train_labels),
|
(self.train_data, self.train_labels),
|
||||||
(self.test_data, self.test_labels),
|
(self.test_data, self.test_labels),
|
||||||
[])
|
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIsInstance(results, dict)
|
self.assertIsInstance(results, AttackResults)
|
||||||
self.assertIn('all_thresh_loss_auc', results)
|
attack_properties, attack_values = get_all_attack_results(results)
|
||||||
self.assertIn('all_thresh_loss_advantage', results)
|
self.assertLen(attack_properties, 2)
|
||||||
|
self.assertLen(attack_values, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -15,13 +15,19 @@
|
||||||
# Lint as: python3
|
# Lint as: python3
|
||||||
"""A hook and a function in tf estimator for membership inference attack."""
|
"""A hook and a function in tf estimator for membership inference attack."""
|
||||||
|
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||||
|
|
||||||
|
@ -49,16 +55,17 @@ def calculate_losses(estimator, input_fn, labels):
|
||||||
|
|
||||||
|
|
||||||
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
"""Training hook to perform membership inference attack after an epoch."""
|
"""Training hook to perform membership inference attack on epoch end."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
estimator,
|
self,
|
||||||
in_train,
|
estimator,
|
||||||
out_train,
|
in_train, out_train,
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
attack_classifiers,
|
slicing_spec: SlicingSpec = None,
|
||||||
writer=None):
|
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||||
"""Initalizes the hook.
|
writer=None):
|
||||||
|
"""Initialize the hook.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
estimator: model to be tested
|
estimator: model to be tested
|
||||||
|
@ -66,8 +73,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
out_train: (out_training samples, out_training labels)
|
out_train: (out_training samples, out_training labels)
|
||||||
input_fn_constructor: a function that receives sample, label and construct
|
input_fn_constructor: a function that receives sample, label and construct
|
||||||
the input_fn for model prediction
|
the input_fn for model prediction
|
||||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
slicing_spec: slicing specification of the attack
|
||||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
attack_types: a list of attacks, each of type AttackType
|
||||||
writer: summary writer for tensorboard
|
writer: summary writer for tensorboard
|
||||||
"""
|
"""
|
||||||
in_train_data, self._in_train_labels = in_train
|
in_train_data, self._in_train_labels = in_train
|
||||||
|
@ -79,7 +86,8 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
self._out_train_input_fn = input_fn_constructor(out_train_data,
|
self._out_train_input_fn = input_fn_constructor(out_train_data,
|
||||||
self._out_train_labels)
|
self._out_train_labels)
|
||||||
self._estimator = estimator
|
self._estimator = estimator
|
||||||
self._attack_classifiers = attack_classifiers
|
self._slicing_spec = slicing_spec
|
||||||
|
self._attack_types = attack_types
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
if self._writer:
|
if self._writer:
|
||||||
logging.info('Will write to tensorboard.')
|
logging.info('Will write to tensorboard.')
|
||||||
|
@ -89,19 +97,28 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
||||||
self._in_train_input_fn,
|
self._in_train_input_fn,
|
||||||
self._out_train_input_fn,
|
self._out_train_input_fn,
|
||||||
self._in_train_labels, self._out_train_labels,
|
self._in_train_labels, self._out_train_labels,
|
||||||
self._attack_classifiers)
|
self._slicing_spec,
|
||||||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
self._attack_types)
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
|
|
||||||
|
attack_properties, attack_values = get_all_attack_results(results)
|
||||||
|
print('Attack result:')
|
||||||
|
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||||
|
zip(attack_properties, attack_values)]))
|
||||||
|
|
||||||
# Write to tensorboard if writer is specified
|
# Write to tensorboard if writer is specified
|
||||||
global_step = self._estimator.get_variable_value('global_step')
|
global_step = self._estimator.get_variable_value('global_step')
|
||||||
write_to_tensorboard(self._writer, ['attack advantage'],
|
attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
|
||||||
[results['all_thresh_loss_advantage']], global_step)
|
write_to_tensorboard(self._writer, attack_property_tags, attack_values,
|
||||||
|
global_step)
|
||||||
|
|
||||||
|
|
||||||
def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
def run_attack_on_tf_estimator_model(
|
||||||
input_fn_constructor, attack_classifiers):
|
estimator, in_train, out_train,
|
||||||
"""A function to perform the attack in the end of training.
|
input_fn_constructor,
|
||||||
|
slicing_spec: SlicingSpec = None,
|
||||||
|
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
||||||
|
"""Performs the attack in the end of training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
estimator: model to be tested
|
estimator: model to be tested
|
||||||
|
@ -109,8 +126,8 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
||||||
out_train: (out_training samples, out_training labels)
|
out_train: (out_training samples, out_training labels)
|
||||||
input_fn_constructor: a function that receives sample, label and construct
|
input_fn_constructor: a function that receives sample, label and construct
|
||||||
the input_fn for model prediction
|
the input_fn for model prediction
|
||||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
slicing_spec: slicing specification of the attack
|
||||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
attack_types: a list of attacks, each of type AttackType
|
||||||
Returns:
|
Returns:
|
||||||
Results of the attack
|
Results of the attack
|
||||||
"""
|
"""
|
||||||
|
@ -125,17 +142,19 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
|
||||||
results = run_attack_helper(estimator,
|
results = run_attack_helper(estimator,
|
||||||
in_train_input_fn, out_train_input_fn,
|
in_train_input_fn, out_train_input_fn,
|
||||||
in_train_labels, out_train_labels,
|
in_train_labels, out_train_labels,
|
||||||
attack_classifiers)
|
slicing_spec,
|
||||||
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
|
attack_types)
|
||||||
logging.info('End of training attack:')
|
logging.info('End of training attack:')
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def run_attack_helper(estimator,
|
def run_attack_helper(
|
||||||
in_train_input_fn, out_train_input_fn,
|
estimator,
|
||||||
in_train_labels, out_train_labels,
|
in_train_input_fn, out_train_input_fn,
|
||||||
attack_classifiers):
|
in_train_labels, out_train_labels,
|
||||||
|
slicing_spec: SlicingSpec = None,
|
||||||
|
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
||||||
"""A helper function to perform attack.
|
"""A helper function to perform attack.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -144,8 +163,8 @@ def run_attack_helper(estimator,
|
||||||
out_train_input_fn: input_fn for out of training data
|
out_train_input_fn: input_fn for out of training data
|
||||||
in_train_labels: in training labels
|
in_train_labels: in training labels
|
||||||
out_train_labels: out of training labels
|
out_train_labels: out of training labels
|
||||||
attack_classifiers: a list of classifiers to be used by attacker, must be
|
slicing_spec: slicing specification of the attack
|
||||||
a subset of ['lr', 'mlp', 'rf', 'knn']
|
attack_types: a list of attacks, each of type AttackType
|
||||||
Returns:
|
Returns:
|
||||||
Results of the attack
|
Results of the attack
|
||||||
"""
|
"""
|
||||||
|
@ -156,9 +175,13 @@ def run_attack_helper(estimator,
|
||||||
out_train_pred, out_train_loss = calculate_losses(estimator,
|
out_train_pred, out_train_loss = calculate_losses(estimator,
|
||||||
out_train_input_fn,
|
out_train_input_fn,
|
||||||
out_train_labels)
|
out_train_labels)
|
||||||
results = mia.run_all_attacks(in_train_loss, out_train_loss,
|
attack_input = AttackInputData(
|
||||||
in_train_pred, out_train_pred,
|
logits_train=in_train_pred, logits_test=out_train_pred,
|
||||||
in_train_labels, out_train_labels,
|
labels_train=in_train_labels, labels_test=out_train_labels,
|
||||||
attack_classifiers=attack_classifiers)
|
loss_train=in_train_loss, loss_test=out_train_loss
|
||||||
|
)
|
||||||
|
results = mia.run_attacks(attack_input,
|
||||||
|
slicing_spec=slicing_spec,
|
||||||
|
attack_types=attack_types)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
|
@ -21,9 +21,11 @@ from absl import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
|
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
|
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||||
|
|
||||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
|
|
||||||
|
@ -97,9 +99,9 @@ def load_mnist():
|
||||||
|
|
||||||
|
|
||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
tf.logging.set_verbosity(tf.logging.ERROR)
|
||||||
logging.set_verbosity(logging.INFO)
|
logging.set_verbosity(logging.ERROR)
|
||||||
logging.set_stderrthreshold(logging.INFO)
|
logging.set_stderrthreshold(logging.ERROR)
|
||||||
logging.get_absl_handler().use_absl_log_file()
|
logging.get_absl_handler().use_absl_log_file()
|
||||||
|
|
||||||
# Load training and test data.
|
# Load training and test data.
|
||||||
|
@ -121,12 +123,13 @@ def main(unused_argv):
|
||||||
summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
|
summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
|
||||||
else:
|
else:
|
||||||
summary_writer = None
|
summary_writer = None
|
||||||
mia_hook = MembershipInferenceTrainingHook(mnist_classifier,
|
mia_hook = MembershipInferenceTrainingHook(
|
||||||
(train_data, train_labels),
|
mnist_classifier,
|
||||||
(test_data, test_labels),
|
(train_data, train_labels),
|
||||||
input_fn_constructor,
|
(test_data, test_labels),
|
||||||
[],
|
input_fn_constructor,
|
||||||
summary_writer)
|
attack_types=[AttackType.THRESHOLD_ATTACK],
|
||||||
|
writer=summary_writer)
|
||||||
|
|
||||||
# Create tf.Estimator input functions for the training and test data.
|
# Create tf.Estimator input functions for the training and test data.
|
||||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||||
|
@ -151,11 +154,17 @@ def main(unused_argv):
|
||||||
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
||||||
|
|
||||||
print('End of training attack')
|
print('End of training attack')
|
||||||
run_attack_on_tf_estimator_model(mnist_classifier,
|
attack_results = run_attack_on_tf_estimator_model(
|
||||||
(train_data, train_labels),
|
mnist_classifier,
|
||||||
(test_data, test_labels),
|
(train_data, train_labels),
|
||||||
input_fn_constructor,
|
(test_data, test_labels),
|
||||||
['lr'])
|
input_fn_constructor,
|
||||||
|
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
|
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
||||||
|
)
|
||||||
|
attack_properties, attack_values = get_all_attack_results(attack_results)
|
||||||
|
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||||
|
zip(attack_properties, attack_values)]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -21,6 +21,9 @@ import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
|
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||||
|
|
||||||
|
|
||||||
class UtilsTest(absltest.TestCase):
|
class UtilsTest(absltest.TestCase):
|
||||||
|
@ -77,15 +80,17 @@ class UtilsTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_run_attack_helper(self):
|
def test_run_attack_helper(self):
|
||||||
"""Test the attack."""
|
"""Test the attack."""
|
||||||
results = tf_estimator_evaluation.run_attack_helper(self.classifier,
|
results = tf_estimator_evaluation.run_attack_helper(
|
||||||
self.input_fn_train,
|
self.classifier,
|
||||||
self.input_fn_test,
|
self.input_fn_train,
|
||||||
self.train_labels,
|
self.input_fn_test,
|
||||||
self.test_labels,
|
self.train_labels,
|
||||||
[])
|
self.test_labels,
|
||||||
self.assertIsInstance(results, dict)
|
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIn('all_thresh_loss_auc', results)
|
self.assertIsInstance(results, AttackResults)
|
||||||
self.assertIn('all_thresh_loss_advantage', results)
|
attack_properties, attack_values = get_all_attack_results(results)
|
||||||
|
self.assertLen(attack_properties, 2)
|
||||||
|
self.assertLen(attack_values, 2)
|
||||||
|
|
||||||
def test_run_attack_on_tf_estimator_model(self):
|
def test_run_attack_on_tf_estimator_model(self):
|
||||||
"""Test the attack on the final models."""
|
"""Test the attack on the final models."""
|
||||||
|
@ -97,10 +102,11 @@ class UtilsTest(absltest.TestCase):
|
||||||
(self.train_data, self.train_labels),
|
(self.train_data, self.train_labels),
|
||||||
(self.test_data, self.test_labels),
|
(self.test_data, self.test_labels),
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
[])
|
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIsInstance(results, dict)
|
self.assertIsInstance(results, AttackResults)
|
||||||
self.assertIn('all_thresh_loss_auc', results)
|
attack_properties, attack_values = get_all_attack_results(results)
|
||||||
self.assertIn('all_thresh_loss_advantage', results)
|
self.assertLen(attack_properties, 2)
|
||||||
|
self.assertLen(attack_values, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -20,6 +20,7 @@ from typing import Text, Dict, Union, List, Any, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
|
||||||
|
|
||||||
ArrayDict = Dict[Text, np.ndarray]
|
ArrayDict = Dict[Text, np.ndarray]
|
||||||
|
@ -73,6 +74,25 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
|
||||||
return {prefix + k: v for k, v in in_dict.items()}
|
return {prefix + k: v for k, v in in_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Utilities for managing result.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_attack_results(results: AttackResults):
|
||||||
|
"""Get all results as a list of attack properties and a list of attack result."""
|
||||||
|
properties = []
|
||||||
|
values = []
|
||||||
|
for attack_result in results.single_attack_results:
|
||||||
|
slice_spec = attack_result.slice_spec
|
||||||
|
prop = [str(slice_spec), str(attack_result.attack_type)]
|
||||||
|
properties += [prop + ['adv'], prop + ['auc']]
|
||||||
|
values += [float(attack_result.get_attacker_advantage()),
|
||||||
|
float(attack_result.get_auc())]
|
||||||
|
|
||||||
|
return properties, values
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
# Subsampling and data selection functionality
|
# Subsampling and data selection functionality
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
Loading…
Reference in a new issue