In binary log loss for membership inference attack, allow prediction to have shape (n, 1).

PiperOrigin-RevId: 432267275
This commit is contained in:
Shuang Song 2022-03-03 13:17:05 -08:00 committed by A. Unique TensorFlower
parent 767788e9cf
commit 2c65cc7910
2 changed files with 16 additions and 3 deletions

View file

@ -38,13 +38,17 @@ def log_loss(labels: np.ndarray,
Returns: Returns:
the cross-entropy loss of each sample the cross-entropy loss of each sample
""" """
if labels.shape[0] != pred.shape[0]:
raise ValueError('labels and pred should have the same number of examples,',
f'but got {labels.shape[0]} and {pred.shape[0]}.')
classes = np.unique(labels) classes = np.unique(labels)
# Binary logistic loss # Binary logistic loss
if pred.ndim == 1: if pred.size == pred.shape[0]:
pred = pred.flatten()
if classes.min() < 0 or classes.max() > 1: if classes.min() < 0 or classes.max() > 1:
raise ValueError('Each value in pred is a scalar, but labels are not in', raise ValueError('Each value in pred is a scalar, so labels are expected',
'{0, 1}.') f'to be {0, 1}. But got {classes}.')
if from_logits: if from_logits:
pred = special.expit(pred) pred = special.expit(pred)

View file

@ -85,6 +85,8 @@ class TestLogLoss(parameterized.TestCase):
y = np.full(pred.shape[0], label) y = np.full(pred.shape[0], label)
loss = utils.log_loss(y, pred) loss = utils.log_loss(y, pred)
np.testing.assert_allclose(expected_loss, loss, atol=1e-7) np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
loss = utils.log_loss(y, pred.reshape(-1, 1))
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
@parameterized.named_parameters( @parameterized.named_parameters(
('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])), ('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
@ -95,6 +97,8 @@ class TestLogLoss(parameterized.TestCase):
y = np.full(pred.shape[0], label) y = np.full(pred.shape[0], label)
loss = utils.log_loss(y, pred, from_logits=True) loss = utils.log_loss(y, pred, from_logits=True)
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2) np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
loss = utils.log_loss(y, pred.reshape(-1, 1), from_logits=True)
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
@parameterized.named_parameters( @parameterized.named_parameters(
('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))), ('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
@ -104,6 +108,11 @@ class TestLogLoss(parameterized.TestCase):
def test_log_loss_wrong_classes(self, labels, pred): def test_log_loss_wrong_classes(self, labels, pred):
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred) self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
def test_log_loss_wrong_number_of_example(self):
labels = np.array([0, 1, 1])
pred = np.array([0.2])
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
class TestSquaredLoss(parameterized.TestCase): class TestSquaredLoss(parameterized.TestCase):