Fix the calculation of loss from logits.

PiperOrigin-RevId: 329966058
This commit is contained in:
Vadym Doroshenko 2020-09-03 12:06:07 -07:00 committed by A. Unique TensorFlower
parent f4fc9b2623
commit 8f3a61b50d
10 changed files with 144 additions and 109 deletions

View file

@ -22,7 +22,7 @@ from dataclasses import dataclass
import numpy as np
import pandas as pd
from sklearn import metrics
import tensorflow_privacy.privacy.membership_inference_attack.utils as utils
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
@ -173,21 +173,19 @@ class AttackInputData:
'Please set labels_train and labels_test')
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
@staticmethod
def _get_loss(logits: np.ndarray, true_labels: np.ndarray):
return logits[range(logits.shape[0]), true_labels]
def get_loss_train(self):
"""Calculates cross-entropy losses for the training set."""
if self.loss_train is not None:
return self.loss_train
return self._get_loss(self.logits_train, self.labels_train)
"""Calculates (if needed) cross-entropy losses for the training set."""
if self.loss_train is None:
self.loss_train = utils.log_loss_from_logits(self.labels_train,
self.logits_train)
return self.loss_train
def get_loss_test(self):
"""Calculates cross-entropy losses for the test set."""
if self.loss_test is not None:
return self.loss_test
return self._get_loss(self.logits_test, self.labels_test)
"""Calculates (if needed) cross-entropy losses for the test set."""
if self.loss_test is None:
self.loss_test = utils.log_loss_from_logits(self.labels_test,
self.logits_test)
return self.loss_test
def get_train_size(self):
"""Returns size of the training set."""
@ -365,11 +363,13 @@ class AttackResults:
advantages.append(float(attack_result.get_attacker_advantage()))
aucs.append(float(attack_result.get_auc()))
df = pd.DataFrame({'slice feature': slice_features,
'slice value': slice_values,
'attack type': attack_types,
'attack advantage': advantages,
'roc auc': aucs})
df = pd.DataFrame({
'slice feature': slice_features,
'slice value': slice_values,
'attack type': attack_types,
'attack advantage': advantages,
'roc auc': aucs
})
return df
def summary(self, by_slices=False) -> str:
@ -452,3 +452,27 @@ class AttackResults:
"""Loads AttackResults from a pickle file."""
with open(filepath, 'rb') as inp:
return pickle.load(inp)
def get_flattened_attack_metrics(results: AttackResults):
  """Flattens attack results into parallel lists of descriptors and values.

  Every single attack result contributes two consecutive entries: its attacker
  advantage (tagged 'adv') followed by its AUC (tagged 'auc').

  Args:
    results: membership inference attack results.

  Returns:
    properties: a list of [slice, attack type, metric name] string lists.
    values: a list of float metric values; values[i] corresponds to
      properties[i].
  """
  properties = []
  values = []
  for single_result in results.single_attack_results:
    descriptor = [
        str(single_result.slice_spec),
        str(single_result.attack_type),
    ]
    # Keep the order (advantage first, then AUC) so that both lists stay
    # aligned index by index.
    for metric_name, metric_getter in (
        ('adv', single_result.get_attacker_advantage),
        ('auc', single_result.get_auc)):
      properties.append(descriptor + [metric_name])
      values.append(float(metric_getter()))
  return properties, values

View file

@ -47,13 +47,15 @@ class AttackInputDataTest(absltest.TestCase):
def test_get_loss(self):
attack_input = AttackInputData(
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]),
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]),
logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
labels_train=np.array([1, 0]),
labels_test=np.array([0, 1]))
labels_test=np.array([0, 2]))
np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2])
np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5])
np.testing.assert_allclose(
attack_input.get_loss_train(), [0.36313551, 1.37153903], atol=1e-7)
np.testing.assert_allclose(
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData(
@ -237,8 +239,9 @@ class AttackResultsTest(absltest.TestCase):
self.assertEqual(repr(results), repr(loaded_results))
def test_calculate_pd_dataframe(self):
single_results = [self.perfect_classifier_result,
self.random_classifier_result]
single_results = [
self.perfect_classifier_result, self.random_classifier_result
]
results = AttackResults(single_results)
df = results.calculate_pd_dataframe()
df_expected = pd.DataFrame({

View file

@ -24,8 +24,8 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
@ -86,7 +86,7 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
self._attack_types)
logging.info(results)
attack_properties, attack_values = get_all_attack_results(results)
attack_properties, attack_values = get_flattened_attack_metrics(results)
print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))

View file

@ -21,11 +21,10 @@ from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
@ -40,11 +39,16 @@ flags.DEFINE_string('model_dir', None, 'Model directory.')
def cnn_model():
"""Define a CNN model."""
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, 8, strides=2, padding='same',
activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(
16,
8,
strides=2,
padding='same',
activation='relu',
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPool2D(2, 1),
tf.keras.layers.Conv2D(32, 4, strides=2, padding='valid',
activation='relu'),
tf.keras.layers.Conv2D(
32, 4, strides=2, padding='valid', activation='relu'),
tf.keras.layers.MaxPool2D(2, 1),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(32, activation='relu'),
@ -83,31 +87,34 @@ def main(unused_argv):
# Get callback for membership inference attack.
mia_callback = MembershipInferenceCallback(
(train_data, train_labels),
(test_data, test_labels),
(train_data, train_labels), (test_data, test_labels),
attack_types=[AttackType.THRESHOLD_ATTACK],
tensorboard_dir=FLAGS.model_dir)
# Train model with Keras
model.fit(train_data, train_labels,
epochs=FLAGS.epochs,
validation_data=(test_data, test_labels),
batch_size=FLAGS.batch_size,
callbacks=[mia_callback],
verbose=2)
model.fit(
train_data,
train_labels,
epochs=FLAGS.epochs,
validation_data=(test_data, test_labels),
batch_size=FLAGS.batch_size,
callbacks=[mia_callback],
verbose=2)
print('End of training attack:')
attack_results = run_attack_on_keras_model(
model,
(train_data, train_labels),
(test_data, test_labels),
model, (train_data, train_labels), (test_data, test_labels),
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
)
attack_types=[
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
])
attack_properties, attack_values = get_all_attack_results(attack_results)
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))
attack_properties, attack_values = get_flattened_attack_metrics(
attack_results)
print('\n'.join([
' %s: %.4f' % (', '.join(p), r)
for p, r in zip(attack_properties, attack_values)
]))
if __name__ == '__main__':

View file

@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
class UtilsTest(absltest.TestCase):
@ -67,7 +67,7 @@ class UtilsTest(absltest.TestCase):
(self.test_data, self.test_labels),
attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults)
attack_properties, attack_values = get_all_attack_results(results)
attack_properties, attack_values = get_flattened_attack_metrics(results)
self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2)

View file

@ -26,8 +26,8 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
@ -101,7 +101,7 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
self._attack_types)
logging.info(results)
attack_properties, attack_values = get_all_attack_results(results)
attack_properties, attack_values = get_flattened_attack_metrics(results)
print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))

View file

@ -22,10 +22,10 @@ from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
@ -63,9 +63,7 @@ def cnn_model_fn(features, labels, mode):
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(
mode=mode,
loss=scalar_loss,
train_op=train_op)
mode=mode, loss=scalar_loss, train_op=train_op)
# Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL:
@ -108,8 +106,8 @@ def main(unused_argv):
train_data, train_labels, test_data, test_labels = load_mnist()
# Instantiate the tf.Estimator.
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
model_dir=FLAGS.model_dir)
mnist_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir=FLAGS.model_dir)
# A function to construct input_fn given (data, label), to be used by the
# membership inference training hook.
@ -124,9 +122,7 @@ def main(unused_argv):
else:
summary_writer = None
mia_hook = MembershipInferenceTrainingHook(
mnist_classifier,
(train_data, train_labels),
(test_data, test_labels),
mnist_classifier, (train_data, train_labels), (test_data, test_labels),
input_fn_constructor,
attack_types=[AttackType.THRESHOLD_ATTACK],
writer=summary_writer)
@ -145,8 +141,8 @@ def main(unused_argv):
steps_per_epoch = 60000 // FLAGS.batch_size
for epoch in range(1, FLAGS.epochs + 1):
# Train the model, with the membership inference hook.
mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch,
hooks=[mia_hook])
mnist_classifier.train(
input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook])
# Evaluate the model and print results
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
@ -155,16 +151,18 @@ def main(unused_argv):
print('End of training attack')
attack_results = run_attack_on_tf_estimator_model(
mnist_classifier,
(train_data, train_labels),
(test_data, test_labels),
mnist_classifier, (train_data, train_labels), (test_data, test_labels),
input_fn_constructor,
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
)
attack_properties, attack_values = get_all_attack_results(attack_results)
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))
attack_types=[
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
])
attack_properties, attack_values = get_flattened_attack_metrics(
attack_results)
print('\n'.join([
' %s: %.4f' % (', '.join(p), r)
for p, r in zip(attack_properties, attack_values)
]))
if __name__ == '__main__':

View file

@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
class UtilsTest(absltest.TestCase):
@ -88,7 +88,7 @@ class UtilsTest(absltest.TestCase):
self.test_labels,
attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults)
attack_properties, attack_values = get_all_attack_results(results)
attack_properties, attack_values = get_flattened_attack_metrics(results)
self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2)
@ -104,7 +104,7 @@ class UtilsTest(absltest.TestCase):
input_fn_constructor,
attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults)
attack_properties, attack_values = get_all_attack_results(results)
attack_properties, attack_values = get_flattened_attack_metrics(results)
self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2)

View file

@ -18,10 +18,9 @@
from typing import Text, Dict, Union, List, Any, Tuple
import numpy as np
import scipy.special
from sklearn import metrics
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
ArrayDict = Dict[Text, np.ndarray]
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
@ -74,25 +73,6 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
return {prefix + k: v for k, v in in_dict.items()}
# ------------------------------------------------------------------------------
# Utilities for managing result.
# ------------------------------------------------------------------------------
def get_all_attack_results(results: AttackResults):
  """Collects every attack metric as parallel property and value lists.

  Args:
    results: membership inference attack results.

  Returns:
    properties: a list of [slice, attack type, metric name] string lists,
      where the metric name is 'adv' or 'auc'.
    values: a list of float metric values; values[i] belongs to properties[i].
  """
  properties = []
  values = []
  for single_result in results.single_attack_results:
    prefix = [str(single_result.slice_spec), str(single_result.attack_type)]
    # Two entries per attack result: attacker advantage, then AUC.
    properties.extend([prefix + ['adv'], prefix + ['auc']])
    values.extend([
        float(single_result.get_attacker_advantage()),
        float(single_result.get_auc()),
    ])
  return properties, values
# ------------------------------------------------------------------------------
# Subsampling and data selection functionality
# ------------------------------------------------------------------------------
@ -245,19 +225,24 @@ def compute_performance_metrics(true_labels: np.ndarray,
# ------------------------------------------------------------------------------
def log_loss(y, pred, small_value=1e-8):
def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
"""Compute the cross entropy loss.
Args:
y: numpy array, y[i] is the true label (scalar) of the i-th sample
labels: numpy array, labels[i] is the true label (scalar) of the i-th sample
pred: numpy array, pred[i] is the probability vector of the i-th sample
small_value: np.log can become -inf if the probability is too close to 0,
so the probability is clipped below by small_value.
small_value: np.log can become -inf if the probability is too close to 0, so
the probability is clipped below by small_value.
Returns:
the cross-entropy loss of each sample
"""
return -np.log(np.maximum(pred[range(y.size), y], small_value))
return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
  """Computes the per-sample cross-entropy loss from logits.

  Args:
    labels: numpy array, labels[i] is the true label (scalar) of the i-th
      sample.
    logits: numpy array, logits[i] is the logit vector of the i-th sample.

  Returns:
    the cross-entropy loss of each sample
  """
  # Turn the logits into probabilities along the last axis, then reuse
  # log_loss (with its default small_value) for the actual loss.
  probabilities = scipy.special.softmax(logits, axis=-1)
  return log_loss(labels, probabilities)
# ------------------------------------------------------------------------------

View file

@ -39,8 +39,10 @@ class UtilsTest(absltest.TestCase):
results = utils.compute_performance_metrics(true, pred, threshold=0.5)
for k in ['precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr',
'thresholds', 'auc', 'advantage']:
for k in [
'precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr',
'thresholds', 'auc', 'advantage'
]:
self.assertIn(k, results)
np.testing.assert_almost_equal(results['accuracy'], 1. / 2.)
@ -107,10 +109,16 @@ class UtilsTest(absltest.TestCase):
[0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
# Test the cases when true label (for all samples) is 0 and 1
expected_losses = {
0: np.array([4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
0.10536052, 0.01005034]),
1: np.array([0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
2.30258509, 4.60517019])
0:
np.array([
4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
0.10536052, 0.01005034
]),
1:
np.array([
0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
2.30258509, 4.60517019
])
}
for c in [0, 1]: # true label
y = np.ones(shape=pred.shape[0], dtype=int) * c
@ -139,8 +147,18 @@ class UtilsTest(absltest.TestCase):
expected_losses = np.array([18.42068074, 46.05170186, 115.12925465])
for i, small_value in enumerate(small_values):
loss = utils.log_loss(y, pred, small_value)
np.testing.assert_allclose(loss, np.array([expected_losses[i], 0]),
atol=1e-7)
np.testing.assert_allclose(
loss, np.array([expected_losses[i], 0]), atol=1e-7)
def test_log_loss_from_logits(self):
  """Checks the cross-entropy losses computed directly from logits."""
  labels = np.array([0, 3, 1])
  logits = np.array([
      [1, 2, 0, -1],
      [1, 2, 0, -1],
      [-1, 3, 0, 0],
  ])
  # Negative log of the softmax probability of each sample's true label.
  expected_loss = np.array([1.4401897, 3.4401897, 0.11144278])

  actual_loss = utils.log_loss_from_logits(labels, logits)

  np.testing.assert_allclose(expected_loss, actual_loss, atol=1e-7)
if __name__ == '__main__':