Fix a bug in keras_evaluation and its example.
PiperOrigin-RevId: 420787967
This commit is contained in:
parent
867f3d4c55
commit
f301595ba5
2 changed files with 29 additions and 12 deletions
|
@ -30,19 +30,23 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils
|
|||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
|
||||
|
||||
|
||||
def calculate_losses(model, data, labels):
|
||||
def calculate_losses(model, data, labels, is_logit=False, batch_size=32):
|
||||
"""Calculate losses of model prediction on data, provided true labels.
|
||||
|
||||
Args:
|
||||
model: model to make prediction
|
||||
data: samples
|
||||
labels: true labels of samples (integer valued)
|
||||
is_logit: whether the result of model.predict is logit or probability
|
||||
batch_size: the batch size for model.predict
|
||||
|
||||
Returns:
|
||||
preds: probability vector of each sample
|
||||
pred: probability vector of each sample
|
||||
loss: cross entropy loss of each sample
|
||||
"""
|
||||
pred = model.predict(data)
|
||||
pred = model.predict(data, batch_size=batch_size)
|
||||
if is_logit:
|
||||
pred = tf.nn.softmax(pred).numpy()
|
||||
loss = log_loss(labels, pred)
|
||||
return pred, loss
|
||||
|
||||
|
@ -56,7 +60,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
|||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||
tensorboard_dir=None,
|
||||
tensorboard_merge_classifiers=False):
|
||||
tensorboard_merge_classifiers=False,
|
||||
is_logit=False,
|
||||
batch_size=32):
|
||||
"""Initalizes the callback.
|
||||
|
||||
Args:
|
||||
|
@ -67,12 +73,16 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
|||
tensorboard_dir: directory for tensorboard summary
|
||||
tensorboard_merge_classifiers: if true, plot different classifiers with
|
||||
the same slicing_spec and metric in the same figure
|
||||
is_logit: whether the result of model.predict is logit or probability
|
||||
batch_size: the batch size for model.predict
|
||||
"""
|
||||
self._in_train_data, self._in_train_labels = in_train
|
||||
self._out_train_data, self._out_train_labels = out_train
|
||||
self._slicing_spec = slicing_spec
|
||||
self._attack_types = attack_types
|
||||
self._tensorboard_merge_classifiers = tensorboard_merge_classifiers
|
||||
self._is_logit = is_logit
|
||||
self._batch_size = batch_size
|
||||
if tensorboard_dir:
|
||||
if tensorboard_merge_classifiers:
|
||||
self._writers = {}
|
||||
|
@ -92,7 +102,8 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
|||
(self._in_train_data, self._in_train_labels),
|
||||
(self._out_train_data, self._out_train_labels),
|
||||
self._slicing_spec,
|
||||
self._attack_types)
|
||||
self._attack_types,
|
||||
self._is_logit, self._batch_size)
|
||||
logging.info(results)
|
||||
|
||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
||||
|
@ -110,7 +121,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
|||
def run_attack_on_keras_model(
|
||||
model, in_train, out_train,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||
is_logit: bool = False,
|
||||
batch_size: int = 32):
|
||||
"""Performs the attack on a trained model.
|
||||
|
||||
Args:
|
||||
|
@ -119,6 +132,8 @@ def run_attack_on_keras_model(
|
|||
out_train: a (out_training samples, out_training labels) tuple
|
||||
slicing_spec: slicing specification of the attack
|
||||
attack_types: a list of attacks, each of type AttackType
|
||||
is_logit: whether the result of model.predict is logit or probability
|
||||
batch_size: the batch size for model.predict
|
||||
Returns:
|
||||
Results of the attack
|
||||
"""
|
||||
|
@ -126,10 +141,10 @@ def run_attack_on_keras_model(
|
|||
out_train_data, out_train_labels = out_train
|
||||
|
||||
# Compute predictions and losses
|
||||
in_train_pred, in_train_loss = calculate_losses(model, in_train_data,
|
||||
in_train_labels)
|
||||
out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
|
||||
out_train_labels)
|
||||
in_train_pred, in_train_loss = calculate_losses(
|
||||
model, in_train_data, in_train_labels, is_logit, batch_size)
|
||||
out_train_pred, out_train_loss = calculate_losses(
|
||||
model, out_train_data, out_train_labels, is_logit, batch_size)
|
||||
attack_input = AttackInputData(
|
||||
logits_train=in_train_pred, logits_test=out_train_pred,
|
||||
labels_train=in_train_labels, labels_test=out_train_labels,
|
||||
|
|
|
@ -83,7 +83,8 @@ def main(unused_argv):
|
|||
attack_types=[AttackType.THRESHOLD_ATTACK,
|
||||
AttackType.K_NEAREST_NEIGHBORS],
|
||||
tensorboard_dir=FLAGS.model_dir,
|
||||
tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers)
|
||||
tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers,
|
||||
is_logit=True, batch_size=2048)
|
||||
|
||||
# Train model with Keras
|
||||
model.fit(
|
||||
|
@ -101,7 +102,8 @@ def main(unused_argv):
|
|||
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=[
|
||||
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
|
||||
])
|
||||
],
|
||||
is_logit=True, batch_size=2048)
|
||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
||||
attack_results)
|
||||
print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in
|
||||
|
|
Loading…
Reference in a new issue