Fix a bug in keras_evaluation and its example.
PiperOrigin-RevId: 420787967
This commit is contained in:
parent
867f3d4c55
commit
f301595ba5
2 changed files with 29 additions and 12 deletions
|
@ -30,19 +30,23 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
|
||||||
|
|
||||||
|
|
||||||
def calculate_losses(model, data, labels):
|
def calculate_losses(model, data, labels, is_logit=False, batch_size=32):
|
||||||
"""Calculate losses of model prediction on data, provided true labels.
|
"""Calculate losses of model prediction on data, provided true labels.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: model to make prediction
|
model: model to make prediction
|
||||||
data: samples
|
data: samples
|
||||||
labels: true labels of samples (integer valued)
|
labels: true labels of samples (integer valued)
|
||||||
|
is_logit: whether the result of model.predict is logit or probability
|
||||||
|
batch_size: the batch size for model.predict
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
preds: probability vector of each sample
|
pred: probability vector of each sample
|
||||||
loss: cross entropy loss of each sample
|
loss: cross entropy loss of each sample
|
||||||
"""
|
"""
|
||||||
pred = model.predict(data)
|
pred = model.predict(data, batch_size=batch_size)
|
||||||
|
if is_logit:
|
||||||
|
pred = tf.nn.softmax(pred).numpy()
|
||||||
loss = log_loss(labels, pred)
|
loss = log_loss(labels, pred)
|
||||||
return pred, loss
|
return pred, loss
|
||||||
|
|
||||||
|
@ -56,7 +60,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||||
slicing_spec: SlicingSpec = None,
|
slicing_spec: SlicingSpec = None,
|
||||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||||
tensorboard_dir=None,
|
tensorboard_dir=None,
|
||||||
tensorboard_merge_classifiers=False):
|
tensorboard_merge_classifiers=False,
|
||||||
|
is_logit=False,
|
||||||
|
batch_size=32):
|
||||||
"""Initalizes the callback.
|
"""Initalizes the callback.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -67,12 +73,16 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||||
tensorboard_dir: directory for tensorboard summary
|
tensorboard_dir: directory for tensorboard summary
|
||||||
tensorboard_merge_classifiers: if true, plot different classifiers with
|
tensorboard_merge_classifiers: if true, plot different classifiers with
|
||||||
the same slicing_spec and metric in the same figure
|
the same slicing_spec and metric in the same figure
|
||||||
|
is_logit: whether the result of model.predict is logit or probability
|
||||||
|
batch_size: the batch size for model.predict
|
||||||
"""
|
"""
|
||||||
self._in_train_data, self._in_train_labels = in_train
|
self._in_train_data, self._in_train_labels = in_train
|
||||||
self._out_train_data, self._out_train_labels = out_train
|
self._out_train_data, self._out_train_labels = out_train
|
||||||
self._slicing_spec = slicing_spec
|
self._slicing_spec = slicing_spec
|
||||||
self._attack_types = attack_types
|
self._attack_types = attack_types
|
||||||
self._tensorboard_merge_classifiers = tensorboard_merge_classifiers
|
self._tensorboard_merge_classifiers = tensorboard_merge_classifiers
|
||||||
|
self._is_logit = is_logit
|
||||||
|
self._batch_size = batch_size
|
||||||
if tensorboard_dir:
|
if tensorboard_dir:
|
||||||
if tensorboard_merge_classifiers:
|
if tensorboard_merge_classifiers:
|
||||||
self._writers = {}
|
self._writers = {}
|
||||||
|
@ -92,7 +102,8 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||||
(self._in_train_data, self._in_train_labels),
|
(self._in_train_data, self._in_train_labels),
|
||||||
(self._out_train_data, self._out_train_labels),
|
(self._out_train_data, self._out_train_labels),
|
||||||
self._slicing_spec,
|
self._slicing_spec,
|
||||||
self._attack_types)
|
self._attack_types,
|
||||||
|
self._is_logit, self._batch_size)
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
|
|
||||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
||||||
|
@ -110,7 +121,9 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
||||||
def run_attack_on_keras_model(
|
def run_attack_on_keras_model(
|
||||||
model, in_train, out_train,
|
model, in_train, out_train,
|
||||||
slicing_spec: SlicingSpec = None,
|
slicing_spec: SlicingSpec = None,
|
||||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
|
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||||
|
is_logit: bool = False,
|
||||||
|
batch_size: int = 32):
|
||||||
"""Performs the attack on a trained model.
|
"""Performs the attack on a trained model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -119,6 +132,8 @@ def run_attack_on_keras_model(
|
||||||
out_train: a (out_training samples, out_training labels) tuple
|
out_train: a (out_training samples, out_training labels) tuple
|
||||||
slicing_spec: slicing specification of the attack
|
slicing_spec: slicing specification of the attack
|
||||||
attack_types: a list of attacks, each of type AttackType
|
attack_types: a list of attacks, each of type AttackType
|
||||||
|
is_logit: whether the result of model.predict is logit or probability
|
||||||
|
batch_size: the batch size for model.predict
|
||||||
Returns:
|
Returns:
|
||||||
Results of the attack
|
Results of the attack
|
||||||
"""
|
"""
|
||||||
|
@ -126,10 +141,10 @@ def run_attack_on_keras_model(
|
||||||
out_train_data, out_train_labels = out_train
|
out_train_data, out_train_labels = out_train
|
||||||
|
|
||||||
# Compute predictions and losses
|
# Compute predictions and losses
|
||||||
in_train_pred, in_train_loss = calculate_losses(model, in_train_data,
|
in_train_pred, in_train_loss = calculate_losses(
|
||||||
in_train_labels)
|
model, in_train_data, in_train_labels, is_logit, batch_size)
|
||||||
out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
|
out_train_pred, out_train_loss = calculate_losses(
|
||||||
out_train_labels)
|
model, out_train_data, out_train_labels, is_logit, batch_size)
|
||||||
attack_input = AttackInputData(
|
attack_input = AttackInputData(
|
||||||
logits_train=in_train_pred, logits_test=out_train_pred,
|
logits_train=in_train_pred, logits_test=out_train_pred,
|
||||||
labels_train=in_train_labels, labels_test=out_train_labels,
|
labels_train=in_train_labels, labels_test=out_train_labels,
|
||||||
|
|
|
@ -83,7 +83,8 @@ def main(unused_argv):
|
||||||
attack_types=[AttackType.THRESHOLD_ATTACK,
|
attack_types=[AttackType.THRESHOLD_ATTACK,
|
||||||
AttackType.K_NEAREST_NEIGHBORS],
|
AttackType.K_NEAREST_NEIGHBORS],
|
||||||
tensorboard_dir=FLAGS.model_dir,
|
tensorboard_dir=FLAGS.model_dir,
|
||||||
tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers)
|
tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers,
|
||||||
|
is_logit=True, batch_size=2048)
|
||||||
|
|
||||||
# Train model with Keras
|
# Train model with Keras
|
||||||
model.fit(
|
model.fit(
|
||||||
|
@ -101,7 +102,8 @@ def main(unused_argv):
|
||||||
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
attack_types=[
|
attack_types=[
|
||||||
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
|
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
|
||||||
])
|
],
|
||||||
|
is_logit=True, batch_size=2048)
|
||||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
||||||
attack_results)
|
attack_results)
|
||||||
print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in
|
print('\n'.join([' %s: %.4f' % (', '.join([s, t, m]), v) for t, s, m, v in
|
||||||
|
|
Loading…
Reference in a new issue