From ec7d44237c4082b536be301cc5f50548060505b7 Mon Sep 17 00:00:00 2001
From: Shuang Song
Date: Tue, 22 Feb 2022 12:17:24 -0800
Subject: [PATCH] Allow customized loss functions for membership inference
 attack.

PiperOrigin-RevId: 430267951
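
Example usage of the new fields (a minimal sketch; the import path matches
the files touched below, and the toy arrays are made up for illustration):

    import numpy as np
    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction

    # Use the new squared loss on 1-D logits instead of the default
    # cross-entropy loss.
    attack_input = AttackInputData(
        logits_train=np.array([0.6, 1.2]),
        logits_test=np.array([0.1, 2.0]),
        labels_train=np.array([1., 1.]),
        labels_test=np.array([0., 2.]),
        loss_function=LossFunction.SQUARED,
        loss_function_using_logits=True)
    print(attack_input.get_loss_train())  # (labels - logits)**2 -> [0.16 0.04]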
+ """ + if loss is not None: + return loss + if labels is None or (logits is None and probs is None): + return None + if loss_function_using_logits and logits is None: + raise ValueError('We need logits to compute loss, but it is set to None.') + if not loss_function_using_logits and probs is None: + raise ValueError('We need probs to compute loss, but it is set to None.') + + predictions = logits if loss_function_using_logits else probs + if loss_function == LossFunction.CROSS_ENTROPY: + loss = utils.log_loss(labels, predictions, loss_function_using_logits) + elif loss_function == LossFunction.SQUARED: + loss = utils.squared_loss(labels, predictions) + else: + loss = loss_function(labels, predictions) + return loss + def get_loss_train(self): """Calculates (if needed) cross-entropy losses for the training set. Returns: Loss (or None if neither the loss nor the labels are present). """ - if self.loss_train is None: - if self.labels_train is None: - return None - if self.logits_train is not None: - self.loss_train = utils.log_loss_from_logits(self.labels_train, - self.logits_train) - else: - self.loss_train = utils.log_loss(self.labels_train, self.probs_train) - return self.loss_train + if self.loss_function_using_logits is None: + self.loss_function_using_logits = (self.logits_train is not None) + return self._get_loss(self.loss_train, self.labels_train, self.logits_train, + self.probs_train, self.loss_function, + self.loss_function_using_logits) def get_loss_test(self): """Calculates (if needed) cross-entropy losses for the test set. @@ -270,15 +324,11 @@ class AttackInputData: Returns: Loss (or None if neither the loss nor the labels are present). """ - if self.loss_test is None: - if self.labels_test is None: - return None - if self.logits_test is not None: - self.loss_test = utils.log_loss_from_logits(self.labels_test, - self.logits_test) - else: - self.loss_test = utils.log_loss(self.labels_test, self.probs_test) - return self.loss_test + if self.loss_function_using_logits is None: + self.loss_function_using_logits = bool(self.logits_test) + return self._get_loss(self.loss_test, self.labels_test, self.logits_test, + self.probs_test, self.loss_function, + self.loss_function_using_logits) def get_entropy_train(self): """Calculates prediction entropy for the training set.""" diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index ed6be9a..7dc06e0 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -19,13 +19,13 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np import pandas as pd - from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize +from 
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py
index ed6be9a..7dc06e0 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py
@@ -19,13 +19,13 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np
 import pandas as pd
-
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
+from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
@@ -48,9 +48,9 @@ class SingleSliceSpecTest(parameterized.TestCase):
     self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
 
 
-class AttackInputDataTest(absltest.TestCase):
+class AttackInputDataTest(parameterized.TestCase):
 
-  def test_get_loss_from_logits(self):
+  def test_get_xe_loss_from_logits(self):
     attack_input = AttackInputData(
         logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
         logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
@@ -62,7 +62,7 @@
     np.testing.assert_allclose(
         attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
 
-  def test_get_loss_from_probs(self):
+  def test_get_xe_loss_from_probs(self):
     attack_input = AttackInputData(
         probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
         probs_test=np.array([[0, 0.0001, 0.9999], [0.07, 0.18, 0.75]]),
@@ -74,6 +74,130 @@
     np.testing.assert_allclose(
         attack_input.get_loss_test(), [18.42068074, 0.28768207], atol=1e-7)
 
+  def test_get_binary_xe_loss_from_logits(self):
+    attack_input = AttackInputData(
+        logits_train=np.array([-10, -5, 0., 5, 10]),
+        logits_test=np.array([-10, -5, 0., 5, 10]),
+        labels_train=np.zeros((5,)),
+        labels_test=np.ones((5,)),
+        loss_function_using_logits=True)
+    expected_loss0 = np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])
+    np.testing.assert_allclose(
+        attack_input.get_loss_train(), expected_loss0, rtol=1e-2)
+    np.testing.assert_allclose(
+        attack_input.get_loss_test(), expected_loss0[::-1], rtol=1e-2)
+
+  def test_get_binary_xe_loss_from_probs(self):
+    attack_input = AttackInputData(
+        probs_train=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
+        probs_test=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
+        labels_train=np.zeros((6,)),
+        labels_test=np.ones((6,)),
+        loss_function_using_logits=False)
+
+    expected_loss0 = np.array([
+        0.2231435513, 1.2039728043, 0.1053605157, 4.6051701860, 0.0020020027,
+        0.0080321717
+    ])
+    expected_loss1 = np.array([
+        1.6094379124, 0.3566749439, 2.3025850930, 0.0100503359, 6.2146080984,
+        4.8283137373
+    ])
+    np.testing.assert_allclose(
+        attack_input.get_loss_train(), expected_loss0, atol=1e-7)
+    np.testing.assert_allclose(
+        attack_input.get_loss_test(), expected_loss1, atol=1e-7)
+
+  @parameterized.named_parameters(
+      ('use_logits', True, np.array([1, 0.]), np.array([0, 4.])),
+      ('use_default', None, np.array([1, 0.]), np.array([0, 4.])),
+      ('use_probs', False, np.array([0, 1.]), np.array([1, 1.])),
+  )
+  def test_get_squared_loss(self, loss_function_using_logits, expected_train,
+                            expected_test):
+    attack_input = AttackInputData(
+        logits_train=np.array([0, 0.]),
+        logits_test=np.array([0, 0.]),
+        probs_train=np.array([1, 1.]),
+        probs_test=np.array([1, 1.]),
+        labels_train=np.array([1, 0.]),
+        labels_test=np.array([0, 2.]),
+        loss_function=LossFunction.SQUARED,
+        loss_function_using_logits=loss_function_using_logits,
+    )
+    np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
+    np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)
+
+  @parameterized.named_parameters(
+      ('use_logits', True, np.array([125.]), np.array([121.])),
+      ('use_default', None, np.array([125.]), np.array([121.])),
+      ('use_probs', False, np.array([458.]), np.array([454.])),
+  )
+  def test_get_customized_loss(self, loss_function_using_logits,
+                               expected_train, expected_test):
+
+    def fake_loss(x, y):
+      return 2 * x + y
+
+    attack_input = AttackInputData(
+        logits_train=np.array([123.]),
+        logits_test=np.array([123.]),
+        probs_train=np.array([456.]),
+        probs_test=np.array([456.]),
+        labels_train=np.array([1.]),
+        labels_test=np.array([-1.]),
+        loss_function=fake_loss,
+        loss_function_using_logits=loss_function_using_logits,
+    )
+    np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
+    np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)
+
+  @parameterized.named_parameters(
+      ('both', np.array([0, 0.]), np.array([1, 1.]), np.array([1, 0.])),
+      ('only_logits', np.array([0, 0.]), None, np.array([1, 0.])),
+      ('only_probs', None, np.array([1, 1.]), np.array([0, 1.])),
+  )
+  def test_default_loss_function_using_logits(self, logits, probs, expected):
+    """Tests for `loss_function_using_logits = None`. Should prefer logits."""
+    attack_input = AttackInputData(
+        logits_train=logits,
+        logits_test=logits,
+        probs_train=probs,
+        probs_test=probs,
+        labels_train=np.array([1, 0.]),
+        labels_test=np.array([1, 0.]),
+        loss_function=LossFunction.SQUARED,
+    )
+    np.testing.assert_allclose(attack_input.get_loss_train(), expected)
+    np.testing.assert_allclose(attack_input.get_loss_test(), expected)
+
+  @parameterized.parameters(
+      (None, np.array([1.]), True),
+      (np.array([1.]), None, False),
+  )
+  def test_loss_wrong_input(self, logits, probs, loss_function_using_logits):
+    attack_input = AttackInputData(
+        logits_train=logits,
+        logits_test=logits,
+        probs_train=probs,
+        probs_test=probs,
+        labels_train=np.array([1.]),
+        labels_test=np.array([0.]),
+        loss_function_using_logits=loss_function_using_logits,
+    )
+    self.assertRaises(ValueError, attack_input.get_loss_train)
+    self.assertRaises(ValueError, attack_input.get_loss_test)
+
   def test_get_loss_explicitly_provided(self):
     attack_input = AttackInputData(
         loss_train=np.array([1.0, 3.0, 6.0]),
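
As the `use_default` cases above exercise, leaving `loss_function_using_logits`
as None makes `get_loss_train()` and `get_loss_test()` fall back to logits
whenever logits are available. A sketch mirroring the test values (imports as
in the sketches above):

    import numpy as np
    attack_input = AttackInputData(
        logits_train=np.array([0, 0.]),
        logits_test=np.array([0, 0.]),
        probs_train=np.array([1, 1.]),
        probs_test=np.array([1, 1.]),
        labels_train=np.array([1, 0.]),
        labels_test=np.array([0, 2.]),
        loss_function=LossFunction.SQUARED)
    print(attack_input.get_loss_train())  # [1. 0.], computed from logits
    print(attack_input.get_loss_test())   # [0. 4.], probs are ignored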
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py
index 3610c70..c8fddba 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py
@@ -17,23 +17,58 @@
 import numpy as np
 from scipy import special
 
 
-def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
-  """Compute the cross entropy loss.
+def log_loss(labels: np.ndarray,
+             pred: np.ndarray,
+             from_logits=False,
+             small_value=1e-8) -> np.ndarray:
+  """Computes the per-example cross entropy loss.
 
   Args:
-    labels: numpy array of shape (num_samples,) labels[i] is the true label
-      (scalar) of the i-th sample
-    pred: numpy array of shape(num_samples, num_classes) where pred[i] is the
-      probability vector of the i-th sample
+    labels: numpy array of shape (num_samples,). labels[i] is the true label
+      (scalar) of the i-th sample and is one of {0, 1, ..., num_classes-1}.
+    pred: numpy array of shape (num_samples, num_classes) or (num_samples,).
+      For categorical cross entropy loss, the shape should be (num_samples,
+      num_classes) and pred[i] is the logits or probability vector of the
+      i-th sample. For binary logistic loss, the shape should be
+      (num_samples,) and pred[i] is the logit or probability of the positive
+      class.
+    from_logits: whether `pred` contains logits or probabilities.
     small_value: a scalar. np.log can become -inf if the probability is too
       close to 0, so the probability is clipped below by small_value.
 
   Returns:
     the cross-entropy loss of each sample
   """
+  classes = np.unique(labels)
+
+  # Binary logistic loss
+  if pred.ndim == 1:
+    if classes.min() < 0 or classes.max() > 1:
+      raise ValueError('Each value in pred is a scalar, but labels are not in '
+                       '{0, 1}.')
+    if from_logits:
+      pred = special.expit(pred)
+
+    indices_class0 = (labels == 0)
+    prob_correct = np.copy(pred)
+    prob_correct[indices_class0] = 1 - prob_correct[indices_class0]
+    return -np.log(np.maximum(prob_correct, small_value))
+
+  # Multi-class categorical cross entropy loss
+  if classes.min() < 0 or classes.max() >= pred.shape[1]:
+    raise ValueError('labels should be in the range [0, num_classes-1].')
+  if from_logits:
+    pred = special.softmax(pred, axis=-1)
   return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
 
 
-def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
-  """Compute the cross entropy loss from logits."""
-  return log_loss(labels, special.softmax(logits, axis=-1))
+def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
+  """Computes the per-example squared loss.
+
+  Args:
+    y_true: numpy array of shape (num_samples,) representing the true labels.
+    y_pred: numpy array of shape (num_samples,) representing the predictions.
+
+  Returns:
+    the squared loss of each sample.
+  """
+  return (y_true - y_pred)**2
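
A quick sketch of the extended log_loss (the printed values match the unit
tests below):

    import numpy as np
    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils

    # Binary: 1-D pred holds the probability of the positive class.
    print(utils.log_loss(np.array([0, 1]), np.array([0.2, 0.7])))
    # -> [-log(0.8), -log(0.7)] ~= [0.22314355 0.35667494]

    # Multi-class from logits: softmax is applied internally.
    print(utils.log_loss(np.array([0]), np.array([[1, 2, 0, -1]]),
                         from_logits=True))
    # -> ~[1.4401897]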
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py
index bd9d80e..4e82928 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py
@@ -13,71 +13,107 @@
 # limitations under the License.
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
 
 
-class UtilsTest(absltest.TestCase):
+class TestLogLoss(parameterized.TestCase):
 
-  def test_log_loss(self):
-    """Test computing cross-entropy loss."""
-    # Test binary case with a few normal values
+  @parameterized.named_parameters(
+      ('label0', 0,
+       np.array([
+           4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
+           0.10536052, 0.01005034
+       ])), ('label1', 1,
+             np.array([
+                 0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
+                 2.30258509, 4.60517019
+             ])))
+  def test_log_loss_from_probs_2_classes(self, label, expected_losses):
     pred = np.array([[0.01, 0.99], [0.1, 0.9], [0.25, 0.75], [0.5, 0.5],
                      [0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
-    # Test the cases when true label (for all samples) is 0 and 1
-    expected_losses = {
-        0:
-            np.array([
-                4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
-                0.10536052, 0.01005034
-            ]),
-        1:
-            np.array([
-                0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
-                2.30258509, 4.60517019
-            ])
-    }
-    for c in [0, 1]:  # true label
-      y = np.ones(shape=pred.shape[0], dtype=int) * c
-      loss = utils.log_loss(y, pred)
-      np.testing.assert_allclose(loss, expected_losses[c], atol=1e-7)
+    y = np.full(pred.shape[0], label)
+    loss = utils.log_loss(y, pred)
+    np.testing.assert_allclose(loss, expected_losses, atol=1e-7)
 
-    # Test multiclass case with a few normal values
-    # (values from http://bit.ly/RJJHWA)
+  @parameterized.named_parameters(
+      ('label0', 0, np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034])),
+      ('label1', 1, np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081])),
+      ('label2', 2, np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374])),
+  )
+  def test_log_loss_from_probs_3_classes(self, label, expected_losses):
+    # Values from http://bit.ly/RJJHWA
     pred = np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3],
                      [0.99, 0.002, 0.008]])
-    # Test the cases when true label (for all samples) is 0, 1, and 2
-    expected_losses = {
-        0: np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034]),
-        1: np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081]),
-        2: np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374])
-    }
-    for c in range(3):  # true label
-      y = np.ones(shape=pred.shape[0], dtype=int) * c
-      loss = utils.log_loss(y, pred)
-      np.testing.assert_allclose(loss, expected_losses[c], atol=1e-7)
+    y = np.full(pred.shape[0], label)
+    loss = utils.log_loss(y, pred)
+    np.testing.assert_allclose(loss, expected_losses, atol=1e-7)
 
-    # Test boundary values 0 and 1
-    pred = np.array([[0, 1]] * 2)
+  @parameterized.named_parameters(
+      ('small_value1e-8', 1e-8, 18.42068074),
+      ('small_value1e-20', 1e-20, 46.05170186),
+      ('small_value1e-50', 1e-50, 115.12925465),
+  )
+  def test_log_loss_from_probs_boundary(self, small_value, expected_loss):
+    pred = np.array([[0., 1]] * 2)
     y = np.array([0, 1])
-    small_values = [1e-8, 1e-20, 1e-50]
-    expected_losses = np.array([18.42068074, 46.05170186, 115.12925465])
-    for i, small_value in enumerate(small_values):
-      loss = utils.log_loss(y, pred, small_value)
-      np.testing.assert_allclose(
-          loss, np.array([expected_losses[i], 0]), atol=1e-7)
+    loss = utils.log_loss(y, pred, small_value=small_value)
+    np.testing.assert_allclose(loss, np.array([expected_loss, 0]), atol=1e-7)
 
   def test_log_loss_from_logits(self):
-    """Test computing cross-entropy loss from logits."""
     logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]])
     labels = np.array([0, 3, 1])
     expected_loss = np.array([1.4401897, 3.4401897, 0.11144278])
-    loss = utils.log_loss_from_logits(labels, logits)
+    loss = utils.log_loss(labels, logits, from_logits=True)
     np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
 
+  @parameterized.named_parameters(
+      ('label0', 0,
+       np.array([
+           0.2231435513, 1.2039728043, 0.1053605157, 4.6051701860,
+           0.0020020027, 0.0080321717
+       ])), ('label1', 1,
+             np.array([
+                 1.6094379124, 0.3566749439, 2.3025850930, 0.0100503359,
+                 6.2146080984, 4.8283137373
+             ])))
+  def test_log_loss_binary_from_probs(self, label, expected_loss):
+    pred = np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008])
+    y = np.full(pred.shape[0], label)
+    loss = utils.log_loss(y, pred)
+    np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
+
+  @parameterized.named_parameters(
+      ('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
+      ('label1', 1, np.array([10, 5, 0.6931471825, 0.006715348, 0.000045398])),
+  )
+  def test_log_loss_binary_from_logits(self, label, expected_loss):
+    pred = np.array([-10, -5, 0., 5, 10])
+    y = np.full(pred.shape[0], label)
+    loss = utils.log_loss(y, pred, from_logits=True)
+    np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
+
+  @parameterized.named_parameters(
+      ('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
+      ('binary_wrong_label', np.array([-1, 1]), np.ones((2,))),
+      ('multiclass_wrong_label', np.array([0, 3]), np.ones((2, 3))),
+  )
+  def test_log_loss_wrong_classes(self, labels, pred):
+    self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
+
+
+class TestSquaredLoss(parameterized.TestCase):
+
+  def test_squared_loss(self):
+    y_true = np.array([1, 2, 3, 4.])
+    y_pred = np.array([4, 3, 2, 1.])
+    expected_loss = np.array([9, 1, 1, 9.])
+    loss = utils.squared_loss(y_true, y_pred)
+    np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
+
 
 if __name__ == '__main__':
   absltest.main()
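
For reference, the boundary values in test_log_loss_from_probs_boundary follow
directly from the clipping: with pred = [0., 1.] and label 0, the probability
of the correct class is 0, so the loss is -log(small_value):

    import numpy as np
    print(-np.log(1e-8))   # 18.42068074
    print(-np.log(1e-20))  # 46.05170186
    print(-np.log(1e-50))  # 115.12925465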