Fix a bug in keras_evaluation and its example.

PiperOrigin-RevId: 420787967
This commit is contained in:
Shuang Song 2022-01-10 09:58:12 -08:00 committed by A. Unique TensorFlower
parent 867f3d4c55
commit f301595ba5
2 changed files with 29 additions and 12 deletions

View file

@ -30,19 +30,23 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
def calculate_losses(model, data, labels, is_logit=False, batch_size=32):
  """Runs the model on the samples and computes the per-sample loss.

  Args:
    model: model to make prediction; its `predict` method is called with
      `data` and the `batch_size` keyword
    data: samples
    labels: true labels of samples (integer valued)
    is_logit: whether the result of model.predict is logit or probability
    batch_size: the batch size for model.predict

  Returns:
    pred: probability vector of each sample
    loss: cross entropy loss of each sample
  """
  probabilities = model.predict(data, batch_size=batch_size)
  if is_logit:
    # The model emitted raw logits: normalize them with softmax so that the
    # value handed to log_loss (and returned to the caller) is a probability.
    probabilities = tf.nn.softmax(probabilities).numpy()
  return probabilities, log_loss(labels, probabilities)
@ -56,7 +60,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
slicing_spec: SlicingSpec = None, slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
tensorboard_dir=None, tensorboard_dir=None,
tensorboard_merge_classifiers=False): tensorboard_merge_classifiers=False,
is_logit=False,
batch_size=32):
"""Initializes the callback. """Initializes the callback.
Args: Args:
@ -67,12 +73,16 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
tensorboard_dir: directory for tensorboard summary tensorboard_dir: directory for tensorboard summary
tensorboard_merge_classifiers: if true, plot different classifiers with tensorboard_merge_classifiers: if true, plot different classifiers with
the same slicing_spec and metric in the same figure the same slicing_spec and metric in the same figure
is_logit: whether the result of model.predict is logit or probability
batch_size: the batch size for model.predict
""" """
self._in_train_data, self._in_train_labels = in_train self._in_train_data, self._in_train_labels = in_train
self._out_train_data, self._out_train_labels = out_train self._out_train_data, self._out_train_labels = out_train
self._slicing_spec = slicing_spec self._slicing_spec = slicing_spec
self._attack_types = attack_types self._attack_types = attack_types
self._tensorboard_merge_classifiers = tensorboard_merge_classifiers self._tensorboard_merge_classifiers = tensorboard_merge_classifiers
self._is_logit = is_logit
self._batch_size = batch_size
if tensorboard_dir: if tensorboard_dir:
if tensorboard_merge_classifiers: if tensorboard_merge_classifiers:
self._writers = {} self._writers = {}
@ -92,7 +102,8 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
(self._in_train_data, self._in_train_labels), (self._in_train_data, self._in_train_labels),
(self._out_train_data, self._out_train_labels), (self._out_train_data, self._out_train_labels),
self._slicing_spec, self._slicing_spec,
self._attack_types) self._attack_types,
self._is_logit, self._batch_size)
logging.info(results) logging.info(results)
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
@ -110,7 +121,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
def run_attack_on_keras_model( def run_attack_on_keras_model(
model, in_train, out_train, model, in_train, out_train,
slicing_spec: SlicingSpec = None, slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)): attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
is_logit: bool = False,
batch_size: int = 32):
"""Performs the attack on a trained model. """Performs the attack on a trained model.
Args: Args:
@ -119,6 +132,8 @@ def run_attack_on_keras_model(
out_train: a (out_training samples, out_training labels) tuple out_train: a (out_training samples, out_training labels) tuple
slicing_spec: slicing specification of the attack slicing_spec: slicing specification of the attack
attack_types: a list of attacks, each of type AttackType attack_types: a list of attacks, each of type AttackType
is_logit: whether the result of model.predict is logit or probability
batch_size: the batch size for model.predict
Returns: Returns:
Results of the attack Results of the attack
""" """
@ -126,10 +141,10 @@ def run_attack_on_keras_model(
out_train_data, out_train_labels = out_train out_train_data, out_train_labels = out_train
# Compute predictions and losses # Compute predictions and losses
in_train_pred, in_train_loss = calculate_losses(model, in_train_data, in_train_pred, in_train_loss = calculate_losses(
in_train_labels) model, in_train_data, in_train_labels, is_logit, batch_size)
out_train_pred, out_train_loss = calculate_losses(model, out_train_data, out_train_pred, out_train_loss = calculate_losses(
out_train_labels) model, out_train_data, out_train_labels, is_logit, batch_size)
attack_input = AttackInputData( attack_input = AttackInputData(
logits_train=in_train_pred, logits_test=out_train_pred, logits_train=in_train_pred, logits_test=out_train_pred,
labels_train=in_train_labels, labels_test=out_train_labels, labels_train=in_train_labels, labels_test=out_train_labels,

View file

@ -83,7 +83,8 @@ def main(unused_argv):
attack_types=[AttackType.THRESHOLD_ATTACK, attack_types=[AttackType.THRESHOLD_ATTACK,
AttackType.K_NEAREST_NEIGHBORS], AttackType.K_NEAREST_NEIGHBORS],
tensorboard_dir=FLAGS.model_dir, tensorboard_dir=FLAGS.model_dir,
tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers) tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers,
is_logit=True, batch_size=2048)
# Train model with Keras # Train model with Keras
model.fit( model.fit(
@ -101,7 +102,8 @@ def main(unused_argv):
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[ attack_types=[
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
]) ],
is_logit=True, batch_size=2048)
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
attack_results) attack_results)
print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in