Allow customized loss functions for membership inference attack.

PiperOrigin-RevId: 430267951
Shuang Song 2022-02-22 12:17:24 -08:00 committed by A. Unique TensorFlower
parent 39fa1d361f
commit ec7d44237c
4 changed files with 321 additions and 76 deletions

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py (View file)

@@ -19,7 +19,7 @@ import enum
import glob
import os
import pickle
from typing import Any, Iterable, MutableSequence, Optional, Union
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
import numpy as np
import pandas as pd
@@ -165,6 +165,12 @@ def _log_value(probs, small_value=1e-30):
return -np.log(np.maximum(probs, small_value))
class LossFunction(enum.Enum):
"""An enum that defines loss function to use in `AttackInputData`."""
CROSS_ENTROPY = 'cross_entropy'
SQUARED = 'squared'
@dataclasses.dataclass
class AttackInputData:
"""Input data for running an attack.
@@ -196,6 +202,17 @@ class AttackInputData:
entropy_train: Optional[np.ndarray] = None
entropy_test: Optional[np.ndarray] = None
# If loss is not explicitly specified, this function will be used to derive
# loss from logits and labels. It can be a pre-defined `LossFunction`.
# If a callable is provided, it should take two arguments: the first is
# labels, the second is logits or probs.
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
LossFunction] = LossFunction.CROSS_ENTROPY
# Whether `loss_function` will be called with logits or probs. If not set
# (None), it will be decided by the availability of logits and probs, with
# logits preferred when both are available.
loss_function_using_logits: Optional[bool] = None
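
A minimal usage sketch of the two new fields (hypothetical values; the import path matches the test file below). A callable loss receives (labels, logits_or_probs) and returns per-example losses:

import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# Hypothetical per-example absolute-error loss: takes (labels, logits or probs).
def abs_loss(labels, preds):
  return np.abs(labels - preds)

attack_input = AttackInputData(
    logits_train=np.array([0.5, -0.2]),
    logits_test=np.array([0.1, 0.3]),
    labels_train=np.array([1., 0.]),
    labels_test=np.array([0., 1.]),
    loss_function=abs_loss,
    loss_function_using_logits=True,  # None (default) prefers logits over probs.
)
attack_input.get_loss_train()  # -> array([0.5, 0.2])
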
@property
def num_classes(self):
if self.labels_train is None or self.labels_test is None:
@@ -248,21 +265,58 @@ class AttackInputData:
true_labels]
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
@staticmethod
def _get_loss(
loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
LossFunction],
loss_function_using_logits: Optional[bool]) -> Optional[np.ndarray]:
"""Calculates (if needed) losses.
Args:
loss: the loss of each example.
labels: the scalar label of each example.
logits: the logits vector of each example.
probs: the probability vector of each example.
loss_function: if `loss` is not available but `labels` and one of `logits`
and `probs` are, this function will be used to compute the loss. It
should take (labels, logits or probs) as input.
loss_function_using_logits: whether `loss_function` expects `logits` or
`probs`.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if loss is not None:
return loss
if labels is None or (logits is None and probs is None):
return None
if loss_function_using_logits and logits is None:
raise ValueError('We need logits to compute loss, but it is set to None.')
if not loss_function_using_logits and probs is None:
raise ValueError('We need probs to compute loss, but it is set to None.')
predictions = logits if loss_function_using_logits else probs
if loss_function == LossFunction.CROSS_ENTROPY:
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
elif loss_function == LossFunction.SQUARED:
loss = utils.squared_loss(labels, predictions)
else:
loss = loss_function(labels, predictions)
return loss
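
Continuing the sketch above (same imports, plus LossFunction from this module), the dispatch can be traced directly: the two enum values route to utils.log_loss and utils.squared_loss, and any other callable is applied to (labels, predictions) as-is.

# Hypothetical arrays; _get_loss is a staticmethod, called here for illustration.
AttackInputData._get_loss(
    loss=None,
    labels=np.array([1., 0.]),
    logits=np.array([0.5, -0.2]),
    probs=None,
    loss_function=LossFunction.SQUARED,
    loss_function_using_logits=True)  # -> (labels - logits)**2 = [0.25, 0.04]
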
def get_loss_train(self):
"""Calculates (if needed) cross-entropy losses for the training set.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if self.loss_train is None:
if self.labels_train is None:
return None
if self.logits_train is not None:
self.loss_train = utils.log_loss_from_logits(self.labels_train,
self.logits_train)
else:
self.loss_train = utils.log_loss(self.labels_train, self.probs_train)
return self.loss_train
if self.loss_function_using_logits is None:
self.loss_function_using_logits = (self.logits_train is not None)
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
self.probs_train, self.loss_function,
self.loss_function_using_logits)
def get_loss_test(self):
"""Calculates (if needed) cross-entropy losses for the test set.
@@ -270,15 +324,11 @@ class AttackInputData:
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if self.loss_test is None:
if self.labels_test is None:
return None
if self.logits_test is not None:
self.loss_test = utils.log_loss_from_logits(self.labels_test,
self.logits_test)
else:
self.loss_test = utils.log_loss(self.labels_test, self.probs_test)
return self.loss_test
if self.loss_function_using_logits is None:
self.loss_function_using_logits = (self.logits_test is not None)
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
self.probs_test, self.loss_function,
self.loss_function_using_logits)
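
With loss_function_using_logits left as None, the getters above resolve it from whichever scores are present. A sketch with hypothetical values:

# Only probs supplied, so the default resolution falls back to probs.
probs_only = AttackInputData(
    probs_train=np.array([0.9, 0.4]),
    probs_test=np.array([0.5, 0.5]),
    labels_train=np.array([1., 0.]),
    labels_test=np.array([1., 0.]),
    loss_function=LossFunction.SQUARED,
)
probs_only.get_loss_train()  # -> (labels - probs)**2 = array([0.01, 0.16])
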
def get_entropy_train(self):
"""Calculates prediction entropy for the training set."""

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py (View file)

@@ -19,13 +19,13 @@ from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import pandas as pd
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
@@ -48,9 +48,9 @@ class SingleSliceSpecTest(parameterized.TestCase):
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
class AttackInputDataTest(absltest.TestCase):
class AttackInputDataTest(parameterized.TestCase):
def test_get_loss_from_logits(self):
def test_get_xe_loss_from_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
@@ -62,7 +62,7 @@ class AttackInputDataTest(absltest.TestCase):
np.testing.assert_allclose(
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
def test_get_loss_from_probs(self):
def test_get_xe_loss_from_probs(self):
attack_input = AttackInputData(
probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
probs_test=np.array([[0, 0.0001, 0.9999], [0.07, 0.18, 0.75]]),
@@ -74,6 +74,130 @@ class AttackInputDataTest(absltest.TestCase):
np.testing.assert_allclose(
attack_input.get_loss_test(), [18.42068074, 0.28768207], atol=1e-7)
def test_get_binary_xe_loss_from_logits(self):
attack_input = AttackInputData(
logits_train=np.array([-10, -5, 0., 5, 10]),
logits_test=np.array([-10, -5, 0., 5, 10]),
labels_train=np.zeros((5,)),
labels_test=np.ones((5,)),
loss_function_using_logits=True)
expected_loss0 = np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])
np.testing.assert_allclose(
attack_input.get_loss_train(), expected_loss0, rtol=1e-2)
np.testing.assert_allclose(
attack_input.get_loss_test(), expected_loss0[::-1], rtol=1e-2)
def test_get_binary_xe_loss_from_probs(self):
attack_input = AttackInputData(
probs_train=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
probs_test=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
labels_train=np.zeros((6,)),
labels_test=np.ones((6,)),
loss_function_using_logits=False)
expected_loss0 = np.array([
0.2231435513, 1.2039728043, 0.1053605157, 4.6051701860, 0.0020020027,
0.0080321717
])
expected_loss1 = np.array([
1.6094379124, 0.3566749439, 2.3025850930, 0.0100503359, 6.2146080984,
4.8283137373
])
np.testing.assert_allclose(
attack_input.get_loss_train(), expected_loss0, atol=1e-7)
np.testing.assert_allclose(
attack_input.get_loss_test(), expected_loss1, atol=1e-7)
@parameterized.named_parameters(
('use_logits', True, np.array([1, 0.]), np.array([0, 4.])),
('use_default', None, np.array([1, 0.]), np.array([0, 4.])),
('use_probs', False, np.array([0, 1.]), np.array([1, 1.])),
)
def test_get_squared_loss(self, loss_function_using_logits, expected_train,
expected_test):
attack_input = AttackInputData(
logits_train=np.array([0, 0.]),
logits_test=np.array([0, 0.]),
probs_train=np.array([1, 1.]),
probs_test=np.array([1, 1.]),
labels_train=np.array([1, 0.]),
labels_test=np.array([0, 2.]),
loss_function=LossFunction.SQUARED,
loss_function_using_logits=loss_function_using_logits,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)
@parameterized.named_parameters(
('use_logits', True, np.array([125.]), np.array([121.])),
('use_default', None, np.array([125.]), np.array([121.])),
('use_probs', False, np.array([458.]), np.array([454.])),
)
def test_get_customized_loss(self, loss_function_using_logits, expected_train,
expected_test):
def fake_loss(x, y):
return 2 * x + y
attack_input = AttackInputData(
logits_train=np.array([
123.,
]),
logits_test=np.array([
123.,
]),
probs_train=np.array([
456.,
]),
probs_test=np.array([
456.,
]),
labels_train=np.array([1.]),
labels_test=np.array([-1.]),
loss_function=fake_loss,
loss_function_using_logits=loss_function_using_logits,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)
@parameterized.named_parameters(
('both', np.array([0, 0.]), np.array([1, 1.]), np.array([1, 0.])),
('only_logits', np.array([0, 0.]), None, np.array([1, 0.])),
('only_probs', None, np.array([1, 1.]), np.array([0, 1.])),
)
def test_default_loss_function_using_logits(self, logits, probs, expected):
"""Tests for `loss_function_using_logits = None`. Should prefer logits."""
attack_input = AttackInputData(
logits_train=logits,
logits_test=logits,
probs_train=probs,
probs_test=probs,
labels_train=np.array([1, 0.]),
labels_test=np.array([1, 0.]),
loss_function=LossFunction.SQUARED,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected)
np.testing.assert_allclose(attack_input.get_loss_test(), expected)
@parameterized.parameters(
(None, np.array([1.]), True),
(np.array([1.]), None, False),
)
def test_loss_wrong_input(self, logits, probs, loss_function_using_logits):
attack_input = AttackInputData(
logits_train=logits,
logits_test=logits,
probs_train=probs,
probs_test=probs,
labels_train=np.array([
1.,
]),
labels_test=np.array([0.]),
loss_function_using_logits=loss_function_using_logits,
)
self.assertRaises(ValueError, attack_input.get_loss_train)
self.assertRaises(ValueError, attack_input.get_loss_test)
def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData(
loss_train=np.array([1.0, 3.0, 6.0]),

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py (View file)

@@ -17,23 +17,58 @@ import numpy as np
from scipy import special
def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
"""Compute the cross entropy loss.
def log_loss(labels: np.ndarray,
pred: np.ndarray,
from_logits=False,
small_value=1e-8) -> np.ndarray:
"""Computes the per-example cross entropy loss.
Args:
labels: numpy array of shape (num_samples,) labels[i] is the true label
(scalar) of the i-th sample
pred: numpy array of shape(num_samples, num_classes) where pred[i] is the
probability vector of the i-th sample
labels: numpy array of shape (num_samples,). labels[i] is the true label
(scalar) of the i-th sample and is one of {0, 1, ..., num_classes-1}.
pred: numpy array of shape (num_samples, num_classes) or (num_samples,). For
categorical cross entropy loss, the shape should be (num_samples,
num_classes) and pred[i] is the logits or probability vector of the i-th
sample. For binary logistic loss, the shape should be (num_samples,) and
pred[i] is the logit or probability of the positive class.
from_logits: whether `pred` contains logits or probabilities.
small_value: a scalar. np.log can become -inf if the probability is too
close to 0, so the probability is clipped below by small_value.
Returns:
the cross-entropy loss of each sample
"""
classes = np.unique(labels)
# Binary logistic loss
if pred.ndim == 1:
if classes.min() < 0 or classes.max() > 1:
raise ValueError('Each value in pred is a scalar, but labels are not in '
'{0, 1}.')
if from_logits:
pred = special.expit(pred)
indices_class0 = (labels == 0)
prob_correct = np.copy(pred)
prob_correct[indices_class0] = 1 - prob_correct[indices_class0]
return -np.log(np.maximum(prob_correct, small_value))
# Multi-class categorical cross entropy loss
if classes.min() < 0 or classes.max() >= pred.shape[1]:
raise ValueError('labels should be in the range [0, num_classes-1].')
if from_logits:
pred = special.softmax(pred, axis=-1)
return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
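
A quick worked check of both branches (hypothetical inputs; the expected values match constants in the tests below):

# Multi-class branch: probabilities for two samples over three classes.
log_loss(np.array([1, 0]), np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2]]))
# -> [-log(0.7), -log(0.6)] ~= [0.35667494, 0.51082562]

# Binary branch from logits: expit(0) = 0.5, so the loss is log(2) ~= 0.6931.
log_loss(np.array([1]), np.array([0.]), from_logits=True)
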
def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
"""Compute the cross entropy loss from logits."""
return log_loss(labels, special.softmax(logits, axis=-1))
def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
"""Computes the per-example squared loss.
Args:
y_true: numpy array of shape (num_samples,) representing the true labels.
y_pred: numpy array of shape (num_samples,) representing the predictions.
Returns:
the squared loss of each sample.
"""
return (y_true - y_pred)**2
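
And a one-line check for squared_loss (hypothetical values):

squared_loss(np.array([1., 2.]), np.array([3., 0.]))  # -> array([4., 4.])
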

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py (View file)

@@ -13,71 +13,107 @@
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
class UtilsTest(absltest.TestCase):
class TestLogLoss(parameterized.TestCase):
def test_log_loss(self):
"""Test computing cross-entropy loss."""
# Test binary case with a few normal values
@parameterized.named_parameters(
('label0', 0,
np.array([
4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
0.10536052, 0.01005034
])), ('label1', 1,
np.array([
0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
2.30258509, 4.60517019
])))
def test_log_loss_from_probs_2_classes(self, label, expected_losses):
pred = np.array([[0.01, 0.99], [0.1, 0.9], [0.25, 0.75], [0.5, 0.5],
[0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
# Test the cases when true label (for all samples) is 0 and 1
expected_losses = {
0:
np.array([
4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
0.10536052, 0.01005034
]),
1:
np.array([
0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
2.30258509, 4.60517019
])
}
for c in [0, 1]: # true label
y = np.ones(shape=pred.shape[0], dtype=int) * c
loss = utils.log_loss(y, pred)
np.testing.assert_allclose(loss, expected_losses[c], atol=1e-7)
y = np.full(pred.shape[0], label)
loss = utils.log_loss(y, pred)
np.testing.assert_allclose(loss, expected_losses, atol=1e-7)
# Test multiclass case with a few normal values
# (values from http://bit.ly/RJJHWA)
@parameterized.named_parameters(
('label0', 0, np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034])),
('label1', 1, np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081])),
('label2', 2, np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374])),
)
def test_log_loss_from_probs_3_classes(self, label, expected_losses):
# Values from http://bit.ly/RJJHWA
pred = np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3],
[0.99, 0.002, 0.008]])
# Test the cases when true label (for all samples) is 0, 1, and 2
expected_losses = {
0: np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034]),
1: np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081]),
2: np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374])
}
for c in range(3): # true label
y = np.ones(shape=pred.shape[0], dtype=int) * c
loss = utils.log_loss(y, pred)
np.testing.assert_allclose(loss, expected_losses[c], atol=1e-7)
y = np.full(pred.shape[0], label)
loss = utils.log_loss(y, pred)
np.testing.assert_allclose(loss, expected_losses, atol=1e-7)
# Test boundary values 0 and 1
pred = np.array([[0, 1]] * 2)
@parameterized.named_parameters(
('small_value1e-8', 1e-8, 18.42068074),
('small_value1e-20', 1e-20, 46.05170186),
('small_value1e-50', 1e-50, 115.12925465),
)
def test_log_loss_from_probs_boundary(self, small_value, expected_loss):
pred = np.array([[0., 1]] * 2)
y = np.array([0, 1])
small_values = [1e-8, 1e-20, 1e-50]
expected_losses = np.array([18.42068074, 46.05170186, 115.12925465])
for i, small_value in enumerate(small_values):
loss = utils.log_loss(y, pred, small_value)
np.testing.assert_allclose(
loss, np.array([expected_losses[i], 0]), atol=1e-7)
loss = utils.log_loss(y, pred, small_value=small_value)
np.testing.assert_allclose(loss, np.array([expected_loss, 0]), atol=1e-7)
def test_log_loss_from_logits(self):
"""Test computing cross-entropy loss from logits."""
logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]])
labels = np.array([0, 3, 1])
expected_loss = np.array([1.4401897, 3.4401897, 0.11144278])
loss = utils.log_loss_from_logits(labels, logits)
loss = utils.log_loss(labels, logits, from_logits=True)
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
@parameterized.named_parameters(
('label0', 0,
np.array([
0.2231435513, 1.2039728043, 0.1053605157, 4.6051701860, 0.0020020027,
0.0080321717
])), ('label1', 1,
np.array([
1.6094379124, 0.3566749439, 2.3025850930, 0.0100503359,
6.2146080984, 4.8283137373
])))
def test_log_loss_binary_from_probs(self, label, expected_loss):
pred = np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008])
y = np.full(pred.shape[0], label)
loss = utils.log_loss(y, pred)
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
@parameterized.named_parameters(
('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
('label1', 1, np.array([10, 5, 0.6931471825, 0.006715348, 0.000045398])),
)
def test_log_loss_binary_from_logits(self, label, expected_loss):
pred = np.array([-10, -5, 0., 5, 10])
y = np.full(pred.shape[0], label)
loss = utils.log_loss(y, pred, from_logits=True)
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
@parameterized.named_parameters(
('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
('binary_wrong_label', np.array([-1, 1]), np.ones((2,))),
('multiclass_wrong_label', np.array([0, 3]), np.ones((2, 3))),
)
def test_log_loss_wrong_classes(self, labels, pred):
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
class TestSquaredLoss(parameterized.TestCase):
def test_squared_loss(self):
y_true = np.array([1, 2, 3, 4.])
y_pred = np.array([4, 3, 2, 1.])
expected_loss = np.array([9, 1, 1, 9.])
loss = utils.squared_loss(y_true, y_pred)
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
if __name__ == '__main__':
absltest.main()