forked from 626_privacy/tensorflow_privacy
In binary log loss for membership inference attack, allow prediction to have shape (n, 1).
PiperOrigin-RevId: 432267275
This commit is contained in:
parent
767788e9cf
commit
2c65cc7910
2 changed files with 16 additions and 3 deletions
|
@ -38,13 +38,17 @@ def log_loss(labels: np.ndarray,
|
|||
Returns:
|
||||
the cross-entropy loss of each sample
|
||||
"""
|
||||
if labels.shape[0] != pred.shape[0]:
|
||||
raise ValueError('labels and pred should have the same number of examples,',
|
||||
f'but got {labels.shape[0]} and {pred.shape[0]}.')
|
||||
classes = np.unique(labels)
|
||||
|
||||
# Binary logistic loss
|
||||
if pred.ndim == 1:
|
||||
if pred.size == pred.shape[0]:
|
||||
pred = pred.flatten()
|
||||
if classes.min() < 0 or classes.max() > 1:
|
||||
raise ValueError('Each value in pred is a scalar, but labels are not in',
|
||||
'{0, 1}.')
|
||||
raise ValueError('Each value in pred is a scalar, so labels are expected',
|
||||
f'to be {0, 1}. But got {classes}.')
|
||||
if from_logits:
|
||||
pred = special.expit(pred)
|
||||
|
||||
|
|
|
@ -85,6 +85,8 @@ class TestLogLoss(parameterized.TestCase):
|
|||
y = np.full(pred.shape[0], label)
|
||||
loss = utils.log_loss(y, pred)
|
||||
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
||||
loss = utils.log_loss(y, pred.reshape(-1, 1))
|
||||
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
|
||||
|
@ -95,6 +97,8 @@ class TestLogLoss(parameterized.TestCase):
|
|||
y = np.full(pred.shape[0], label)
|
||||
loss = utils.log_loss(y, pred, from_logits=True)
|
||||
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
|
||||
loss = utils.log_loss(y, pred.reshape(-1, 1), from_logits=True)
|
||||
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
|
||||
|
@ -104,6 +108,11 @@ class TestLogLoss(parameterized.TestCase):
|
|||
def test_log_loss_wrong_classes(self, labels, pred):
|
||||
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
|
||||
|
||||
def test_log_loss_wrong_number_of_example(self):
|
||||
labels = np.array([0, 1, 1])
|
||||
pred = np.array([0.2])
|
||||
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
|
||||
|
||||
|
||||
class TestSquaredLoss(parameterized.TestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue