From 2c65cc7910197fa0b71621e150cfb77f7a0802c5 Mon Sep 17 00:00:00 2001
From: Shuang Song
Date: Thu, 3 Mar 2022 13:17:05 -0800
Subject: [PATCH] In binary log loss for membership inference attack, allow
 prediction to have shape (n, 1).

PiperOrigin-RevId: 432267275
---
 .../privacy_tests/membership_inference_attack/utils.py | 10 +++++++---
 .../membership_inference_attack/utils_test.py          |  9 +++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py
index c8fddba..0ee3ddd 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py
@@ -38,13 +38,17 @@ def log_loss(labels: np.ndarray,
   Returns:
     the cross-entropy loss of each sample
   """
+  if labels.shape[0] != pred.shape[0]:
+    raise ValueError('labels and pred should have the same number of examples,',
+                     f'but got {labels.shape[0]} and {pred.shape[0]}.')
   classes = np.unique(labels)
 
   # Binary logistic loss
-  if pred.ndim == 1:
+  if pred.size == pred.shape[0]:
+    pred = pred.flatten()
     if classes.min() < 0 or classes.max() > 1:
-      raise ValueError('Each value in pred is a scalar, but labels are not in',
-                       '{0, 1}.')
+      raise ValueError('Each value in pred is a scalar, so labels are expected',
+                       f'to be {0, 1}. But got {classes}.')
     if from_logits:
       pred = special.expit(pred)
 
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py
index 4e82928..15f639b 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py
@@ -85,6 +85,8 @@ class TestLogLoss(parameterized.TestCase):
     y = np.full(pred.shape[0], label)
     loss = utils.log_loss(y, pred)
     np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
+    loss = utils.log_loss(y, pred.reshape(-1, 1))
+    np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
 
   @parameterized.named_parameters(
       ('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
@@ -95,6 +97,8 @@ class TestLogLoss(parameterized.TestCase):
     y = np.full(pred.shape[0], label)
     loss = utils.log_loss(y, pred, from_logits=True)
     np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
+    loss = utils.log_loss(y, pred.reshape(-1, 1), from_logits=True)
+    np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
 
   @parameterized.named_parameters(
       ('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
@@ -104,6 +108,11 @@ def test_log_loss_wrong_classes(self, labels, pred):
     self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
 
+  def test_log_loss_wrong_number_of_example(self):
+    labels = np.array([0, 1, 1])
+    pred = np.array([0.2])
+    self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
+
 
 class TestSquaredLoss(parameterized.TestCase):
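
The key change above is replacing the `pred.ndim == 1` check with `pred.size == pred.shape[0]`,
which holds for both (n,) and (n, 1) arrays, followed by a flatten, so column-vector
predictions go through the binary-loss path. Below is a minimal standalone sketch of that
shape handling under the same idea; the helper name `binary_log_loss` is illustrative only
and is not the utils.py implementation.

# Minimal sketch of the (n,) / (n, 1) handling; not the library code.
import numpy as np
from scipy import special


def binary_log_loss(labels, pred, from_logits=False, small_value=1e-8):
  labels = np.asarray(labels)
  pred = np.asarray(pred)
  # Mismatched example counts are rejected, mirroring the new check in the patch.
  if labels.shape[0] != pred.shape[0]:
    raise ValueError('labels and pred must have the same number of examples.')
  # A (n, 1) prediction has size == shape[0], just like a (n,) prediction,
  # so both are accepted and flattened to one score per example.
  if pred.size == pred.shape[0]:
    pred = pred.flatten()
  if from_logits:
    pred = special.expit(pred)  # sigmoid
  pred = np.clip(pred, small_value, 1 - small_value)
  return -(labels * np.log(pred) + (1 - labels) * np.log(1 - pred))


# (n,) and (n, 1) predictions yield identical per-example losses,
# which is exactly what the added reshape(-1, 1) test cases assert.
y = np.array([0, 1, 1])
p = np.array([0.1, 0.8, 0.6])
assert np.allclose(binary_log_loss(y, p), binary_log_loss(y, p.reshape(-1, 1)))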