In binary log loss for membership inference attack, allow prediction to have shape (n, 1).
PiperOrigin-RevId: 432267275
This commit is contained in:
parent
767788e9cf
commit
2c65cc7910
2 changed files with 16 additions and 3 deletions
|
@ -38,13 +38,17 @@ def log_loss(labels: np.ndarray,
|
||||||
Returns:
|
Returns:
|
||||||
the cross-entropy loss of each sample
|
the cross-entropy loss of each sample
|
||||||
"""
|
"""
|
||||||
|
if labels.shape[0] != pred.shape[0]:
|
||||||
|
raise ValueError('labels and pred should have the same number of examples,',
|
||||||
|
f'but got {labels.shape[0]} and {pred.shape[0]}.')
|
||||||
classes = np.unique(labels)
|
classes = np.unique(labels)
|
||||||
|
|
||||||
# Binary logistic loss
|
# Binary logistic loss
|
||||||
if pred.ndim == 1:
|
if pred.size == pred.shape[0]:
|
||||||
|
pred = pred.flatten()
|
||||||
if classes.min() < 0 or classes.max() > 1:
|
if classes.min() < 0 or classes.max() > 1:
|
||||||
raise ValueError('Each value in pred is a scalar, but labels are not in',
|
raise ValueError('Each value in pred is a scalar, so labels are expected',
|
||||||
'{0, 1}.')
|
f'to be {0, 1}. But got {classes}.')
|
||||||
if from_logits:
|
if from_logits:
|
||||||
pred = special.expit(pred)
|
pred = special.expit(pred)
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,8 @@ class TestLogLoss(parameterized.TestCase):
|
||||||
y = np.full(pred.shape[0], label)
|
y = np.full(pred.shape[0], label)
|
||||||
loss = utils.log_loss(y, pred)
|
loss = utils.log_loss(y, pred)
|
||||||
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
||||||
|
loss = utils.log_loss(y, pred.reshape(-1, 1))
|
||||||
|
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
|
('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
|
||||||
|
@ -95,6 +97,8 @@ class TestLogLoss(parameterized.TestCase):
|
||||||
y = np.full(pred.shape[0], label)
|
y = np.full(pred.shape[0], label)
|
||||||
loss = utils.log_loss(y, pred, from_logits=True)
|
loss = utils.log_loss(y, pred, from_logits=True)
|
||||||
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
|
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
|
||||||
|
loss = utils.log_loss(y, pred.reshape(-1, 1), from_logits=True)
|
||||||
|
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
|
('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
|
||||||
|
@ -104,6 +108,11 @@ class TestLogLoss(parameterized.TestCase):
|
||||||
def test_log_loss_wrong_classes(self, labels, pred):
|
def test_log_loss_wrong_classes(self, labels, pred):
|
||||||
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
|
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
|
||||||
|
|
||||||
|
def test_log_loss_wrong_number_of_example(self):
|
||||||
|
labels = np.array([0, 1, 1])
|
||||||
|
pred = np.array([0.2])
|
||||||
|
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
|
||||||
|
|
||||||
|
|
||||||
class TestSquaredLoss(parameterized.TestCase):
|
class TestSquaredLoss(parameterized.TestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue