Allow customized loss functions for membership inference attack.
PiperOrigin-RevId: 430267951
This commit is contained in:
parent
39fa1d361f
commit
ec7d44237c
4 changed files with 321 additions and 76 deletions
|
@ -19,7 +19,7 @@ import enum
|
|||
import glob
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Iterable, MutableSequence, Optional, Union
|
||||
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -165,6 +165,12 @@ def _log_value(probs, small_value=1e-30):
|
|||
return -np.log(np.maximum(probs, small_value))
|
||||
|
||||
|
||||
class LossFunction(enum.Enum):
|
||||
"""An enum that defines loss function to use in `AttackInputData`."""
|
||||
CROSS_ENTROPY = 'cross_entropy'
|
||||
SQUARED = 'squared'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AttackInputData:
|
||||
"""Input data for running an attack.
|
||||
|
@ -196,6 +202,17 @@ class AttackInputData:
|
|||
entropy_train: Optional[np.ndarray] = None
|
||||
entropy_test: Optional[np.ndarray] = None
|
||||
|
||||
# If loss is not explicitly specified, this function will be used to derive
|
||||
# loss from logits and labels. It can be a pre-defined `LossFunction`.
|
||||
# If a callable is provided, it should take in two argument, the 1st is
|
||||
# labels, the 2nd is logits or probs.
|
||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
||||
LossFunction] = LossFunction.CROSS_ENTROPY
|
||||
# Whether `loss_function` will be called with logits or probs. If not set
|
||||
# (None), will decide by availablity of logits and probs and logits is
|
||||
# preferred when both are available.
|
||||
loss_function_using_logits: Optional[bool] = None
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
if self.labels_train is None or self.labels_test is None:
|
||||
|
@ -248,21 +265,58 @@ class AttackInputData:
|
|||
true_labels]
|
||||
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
||||
|
||||
@staticmethod
|
||||
def _get_loss(
|
||||
loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
||||
LossFunction],
|
||||
loss_function_using_logits: Optional[bool]) -> Optional[np.ndarray]:
|
||||
"""Calculates (if needed) losses.
|
||||
|
||||
Args:
|
||||
loss: the loss of each example.
|
||||
labels: the scalar label of each example.
|
||||
logits: the logits vector of each example.
|
||||
probs: the probability vector of each example.
|
||||
loss_function: if `loss` is not available, `labels` and one of `logits`
|
||||
and `probs` are available, we will use this function to compute loss. It
|
||||
is supposed to take in (label, logits / probs) as input.
|
||||
loss_function_using_logits: if `loss_function` expects `logits` or
|
||||
`probs`.
|
||||
|
||||
Returns:
|
||||
Loss (or None if neither the loss nor the labels are present).
|
||||
"""
|
||||
if loss is not None:
|
||||
return loss
|
||||
if labels is None or (logits is None and probs is None):
|
||||
return None
|
||||
if loss_function_using_logits and logits is None:
|
||||
raise ValueError('We need logits to compute loss, but it is set to None.')
|
||||
if not loss_function_using_logits and probs is None:
|
||||
raise ValueError('We need probs to compute loss, but it is set to None.')
|
||||
|
||||
predictions = logits if loss_function_using_logits else probs
|
||||
if loss_function == LossFunction.CROSS_ENTROPY:
|
||||
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
|
||||
elif loss_function == LossFunction.SQUARED:
|
||||
loss = utils.squared_loss(labels, predictions)
|
||||
else:
|
||||
loss = loss_function(labels, predictions)
|
||||
return loss
|
||||
|
||||
def get_loss_train(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the training set.
|
||||
|
||||
Returns:
|
||||
Loss (or None if neither the loss nor the labels are present).
|
||||
"""
|
||||
if self.loss_train is None:
|
||||
if self.labels_train is None:
|
||||
return None
|
||||
if self.logits_train is not None:
|
||||
self.loss_train = utils.log_loss_from_logits(self.labels_train,
|
||||
self.logits_train)
|
||||
else:
|
||||
self.loss_train = utils.log_loss(self.labels_train, self.probs_train)
|
||||
return self.loss_train
|
||||
if self.loss_function_using_logits is None:
|
||||
self.loss_function_using_logits = (self.logits_train is not None)
|
||||
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
|
||||
self.probs_train, self.loss_function,
|
||||
self.loss_function_using_logits)
|
||||
|
||||
def get_loss_test(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the test set.
|
||||
|
@ -270,15 +324,11 @@ class AttackInputData:
|
|||
Returns:
|
||||
Loss (or None if neither the loss nor the labels are present).
|
||||
"""
|
||||
if self.loss_test is None:
|
||||
if self.labels_test is None:
|
||||
return None
|
||||
if self.logits_test is not None:
|
||||
self.loss_test = utils.log_loss_from_logits(self.labels_test,
|
||||
self.logits_test)
|
||||
else:
|
||||
self.loss_test = utils.log_loss(self.labels_test, self.probs_test)
|
||||
return self.loss_test
|
||||
if self.loss_function_using_logits is None:
|
||||
self.loss_function_using_logits = bool(self.logits_test)
|
||||
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
|
||||
self.probs_test, self.loss_function,
|
||||
self.loss_function_using_logits)
|
||||
|
||||
def get_entropy_train(self):
|
||||
"""Calculates prediction entropy for the training set."""
|
||||
|
|
|
@ -19,13 +19,13 @@ from absl.testing import absltest
|
|||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
|
||||
|
@ -48,9 +48,9 @@ class SingleSliceSpecTest(parameterized.TestCase):
|
|||
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
|
||||
|
||||
|
||||
class AttackInputDataTest(absltest.TestCase):
|
||||
class AttackInputDataTest(parameterized.TestCase):
|
||||
|
||||
def test_get_loss_from_logits(self):
|
||||
def test_get_xe_loss_from_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
|
||||
logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
|
||||
|
@ -62,7 +62,7 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
|
||||
|
||||
def test_get_loss_from_probs(self):
|
||||
def test_get_xe_loss_from_probs(self):
|
||||
attack_input = AttackInputData(
|
||||
probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
|
||||
probs_test=np.array([[0, 0.0001, 0.9999], [0.07, 0.18, 0.75]]),
|
||||
|
@ -74,6 +74,130 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), [18.42068074, 0.28768207], atol=1e-7)
|
||||
|
||||
def test_get_binary_xe_loss_from_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([-10, -5, 0., 5, 10]),
|
||||
logits_test=np.array([-10, -5, 0., 5, 10]),
|
||||
labels_train=np.zeros((5,)),
|
||||
labels_test=np.ones((5,)),
|
||||
loss_function_using_logits=True)
|
||||
expected_loss0 = np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_train(), expected_loss0, rtol=1e-2)
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), expected_loss0[::-1], rtol=1e-2)
|
||||
|
||||
def test_get_binary_xe_loss_from_probs(self):
|
||||
attack_input = AttackInputData(
|
||||
probs_train=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
|
||||
probs_test=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
|
||||
labels_train=np.zeros((6,)),
|
||||
labels_test=np.ones((6,)),
|
||||
loss_function_using_logits=False)
|
||||
|
||||
expected_loss0 = np.array([
|
||||
0.2231435513, 1.2039728043, 0.1053605157, 4.6051701860, 0.0020020027,
|
||||
0.0080321717
|
||||
])
|
||||
expected_loss1 = np.array([
|
||||
1.6094379124, 0.3566749439, 2.3025850930, 0.0100503359, 6.2146080984,
|
||||
4.8283137373
|
||||
])
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_train(), expected_loss0, atol=1e-7)
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), expected_loss1, atol=1e-7)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('use_logits', True, np.array([1, 0.]), np.array([0, 4.])),
|
||||
('use_default', None, np.array([1, 0.]), np.array([0, 4.])),
|
||||
('use_probs', False, np.array([0, 1.]), np.array([1, 1.])),
|
||||
)
|
||||
def test_get_squared_loss(self, loss_function_using_logits, expected_train,
|
||||
expected_test):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([0, 0.]),
|
||||
logits_test=np.array([0, 0.]),
|
||||
probs_train=np.array([1, 1.]),
|
||||
probs_test=np.array([1, 1.]),
|
||||
labels_train=np.array([1, 0.]),
|
||||
labels_test=np.array([0, 2.]),
|
||||
loss_function=LossFunction.SQUARED,
|
||||
loss_function_using_logits=loss_function_using_logits,
|
||||
)
|
||||
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
|
||||
np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('use_logits', True, np.array([125.]), np.array([121.])),
|
||||
('use_default', None, np.array([125.]), np.array([121.])),
|
||||
('use_probs', False, np.array([458.]), np.array([454.])),
|
||||
)
|
||||
def test_get_customized_loss(self, loss_function_using_logits, expected_train,
|
||||
expected_test):
|
||||
|
||||
def fake_loss(x, y):
|
||||
return 2 * x + y
|
||||
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([
|
||||
123.,
|
||||
]),
|
||||
logits_test=np.array([
|
||||
123.,
|
||||
]),
|
||||
probs_train=np.array([
|
||||
456.,
|
||||
]),
|
||||
probs_test=np.array([
|
||||
456.,
|
||||
]),
|
||||
labels_train=np.array([1.]),
|
||||
labels_test=np.array([-1.]),
|
||||
loss_function=fake_loss,
|
||||
loss_function_using_logits=loss_function_using_logits,
|
||||
)
|
||||
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
|
||||
np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('both', np.array([0, 0.]), np.array([1, 1.]), np.array([1, 0.])),
|
||||
('only_logits', np.array([0, 0.]), None, np.array([1, 0.])),
|
||||
('only_probs', None, np.array([1, 1.]), np.array([0, 1.])),
|
||||
)
|
||||
def test_default_loss_function_using_logits(self, logits, probs, expected):
|
||||
"""Tests for `loss_function_using_logits = None`. Should prefer logits."""
|
||||
attack_input = AttackInputData(
|
||||
logits_train=logits,
|
||||
logits_test=logits,
|
||||
probs_train=probs,
|
||||
probs_test=probs,
|
||||
labels_train=np.array([1, 0.]),
|
||||
labels_test=np.array([1, 0.]),
|
||||
loss_function=LossFunction.SQUARED,
|
||||
)
|
||||
np.testing.assert_allclose(attack_input.get_loss_train(), expected)
|
||||
np.testing.assert_allclose(attack_input.get_loss_test(), expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
(None, np.array([1.]), True),
|
||||
(np.array([1.]), None, False),
|
||||
)
|
||||
def test_loss_wrong_input(self, logits, probs, loss_function_using_logits):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=logits,
|
||||
logits_test=logits,
|
||||
probs_train=probs,
|
||||
probs_test=probs,
|
||||
labels_train=np.array([
|
||||
1.,
|
||||
]),
|
||||
labels_test=np.array([0.]),
|
||||
loss_function_using_logits=loss_function_using_logits,
|
||||
)
|
||||
self.assertRaises(ValueError, attack_input.get_loss_train)
|
||||
self.assertRaises(ValueError, attack_input.get_loss_test)
|
||||
|
||||
def test_get_loss_explicitly_provided(self):
|
||||
attack_input = AttackInputData(
|
||||
loss_train=np.array([1.0, 3.0, 6.0]),
|
||||
|
|
|
@ -17,23 +17,58 @@ import numpy as np
|
|||
from scipy import special
|
||||
|
||||
|
||||
def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
|
||||
"""Compute the cross entropy loss.
|
||||
def log_loss(labels: np.ndarray,
|
||||
pred: np.ndarray,
|
||||
from_logits=False,
|
||||
small_value=1e-8) -> np.ndarray:
|
||||
"""Computes the per-example cross entropy loss.
|
||||
|
||||
Args:
|
||||
labels: numpy array of shape (num_samples,) labels[i] is the true label
|
||||
(scalar) of the i-th sample
|
||||
pred: numpy array of shape(num_samples, num_classes) where pred[i] is the
|
||||
probability vector of the i-th sample
|
||||
labels: numpy array of shape (num_samples,). labels[i] is the true label
|
||||
(scalar) of the i-th sample and is one of {0, 1, ..., num_classes-1}.
|
||||
pred: numpy array of shape (num_samples, num_classes) or (num_samples,). For
|
||||
categorical cross entropy loss, the shape should be (num_samples,
|
||||
num_classes) and pred[i] is the logits or probability vector of the i-th
|
||||
sample. For binary logistic loss, the shape should be (num_samples,) and
|
||||
pred[i] is the probability of the positive class.
|
||||
from_logits: whether `pred` is logits or probability vector.
|
||||
small_value: a scalar. np.log can become -inf if the probability is too
|
||||
close to 0, so the probability is clipped below by small_value.
|
||||
|
||||
Returns:
|
||||
the cross-entropy loss of each sample
|
||||
"""
|
||||
classes = np.unique(labels)
|
||||
|
||||
# Binary logistic loss
|
||||
if pred.ndim == 1:
|
||||
if classes.min() < 0 or classes.max() > 1:
|
||||
raise ValueError('Each value in pred is a scalar, but labels are not in',
|
||||
'{0, 1}.')
|
||||
if from_logits:
|
||||
pred = special.expit(pred)
|
||||
|
||||
indices_class0 = (labels == 0)
|
||||
prob_correct = np.copy(pred)
|
||||
prob_correct[indices_class0] = 1 - prob_correct[indices_class0]
|
||||
return -np.log(np.maximum(prob_correct, small_value))
|
||||
|
||||
# Multi-class categorical cross entropy loss
|
||||
if classes.min() < 0 or classes.max() >= pred.shape[1]:
|
||||
raise ValueError('labels should be in the range [0, num_classes-1].')
|
||||
if from_logits:
|
||||
pred = special.softmax(pred, axis=-1)
|
||||
return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
|
||||
|
||||
|
||||
def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
|
||||
"""Compute the cross entropy loss from logits."""
|
||||
return log_loss(labels, special.softmax(logits, axis=-1))
|
||||
def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
|
||||
"""Computes the per-example squared loss.
|
||||
|
||||
Args:
|
||||
y_true: numpy array of shape (num_samples,) representing the true labels.
|
||||
y_pred: numpy array of shape (num_samples,) representing the predictions.
|
||||
|
||||
Returns:
|
||||
the squared loss of each sample.
|
||||
"""
|
||||
return (y_true - y_pred)**2
|
||||
|
|
|
@ -13,71 +13,107 @@
|
|||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
|
||||
class TestLogLoss(parameterized.TestCase):
|
||||
|
||||
def test_log_loss(self):
|
||||
"""Test computing cross-entropy loss."""
|
||||
# Test binary case with a few normal values
|
||||
@parameterized.named_parameters(
|
||||
('label0', 0,
|
||||
np.array([
|
||||
4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
|
||||
0.10536052, 0.01005034
|
||||
])), ('label1', 1,
|
||||
np.array([
|
||||
0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
|
||||
2.30258509, 4.60517019
|
||||
])))
|
||||
def test_log_loss_from_probs_2_classes(self, label, expected_losses):
|
||||
pred = np.array([[0.01, 0.99], [0.1, 0.9], [0.25, 0.75], [0.5, 0.5],
|
||||
[0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
|
||||
# Test the cases when true label (for all samples) is 0 and 1
|
||||
expected_losses = {
|
||||
0:
|
||||
np.array([
|
||||
4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
|
||||
0.10536052, 0.01005034
|
||||
]),
|
||||
1:
|
||||
np.array([
|
||||
0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
|
||||
2.30258509, 4.60517019
|
||||
])
|
||||
}
|
||||
for c in [0, 1]: # true label
|
||||
y = np.ones(shape=pred.shape[0], dtype=int) * c
|
||||
loss = utils.log_loss(y, pred)
|
||||
np.testing.assert_allclose(loss, expected_losses[c], atol=1e-7)
|
||||
y = np.full(pred.shape[0], label)
|
||||
loss = utils.log_loss(y, pred)
|
||||
np.testing.assert_allclose(loss, expected_losses, atol=1e-7)
|
||||
|
||||
# Test multiclass case with a few normal values
|
||||
# (values from http://bit.ly/RJJHWA)
|
||||
@parameterized.named_parameters(
|
||||
('label0', 0, np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034])),
|
||||
('label1', 1, np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081])),
|
||||
('label2', 2, np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374])),
|
||||
)
|
||||
def test_log_loss_from_probs_3_classes(self, label, expected_losses):
|
||||
# Values from http://bit.ly/RJJHWA
|
||||
pred = np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3],
|
||||
[0.99, 0.002, 0.008]])
|
||||
# Test the cases when true label (for all samples) is 0, 1, and 2
|
||||
expected_losses = {
|
||||
0: np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034]),
|
||||
1: np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081]),
|
||||
2: np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374])
|
||||
}
|
||||
for c in range(3): # true label
|
||||
y = np.ones(shape=pred.shape[0], dtype=int) * c
|
||||
loss = utils.log_loss(y, pred)
|
||||
np.testing.assert_allclose(loss, expected_losses[c], atol=1e-7)
|
||||
y = np.full(pred.shape[0], label)
|
||||
loss = utils.log_loss(y, pred)
|
||||
np.testing.assert_allclose(loss, expected_losses, atol=1e-7)
|
||||
|
||||
# Test boundary values 0 and 1
|
||||
pred = np.array([[0, 1]] * 2)
|
||||
@parameterized.named_parameters(
|
||||
('small_value1e-8', 1e-8, 18.42068074),
|
||||
('small_value1e-20', 1e-20, 46.05170186),
|
||||
('small_value1e-50', 1e-50, 115.12925465),
|
||||
)
|
||||
def test_log_loss_from_probs_boundary(self, small_value, expected_loss):
|
||||
pred = np.array([[0., 1]] * 2)
|
||||
y = np.array([0, 1])
|
||||
small_values = [1e-8, 1e-20, 1e-50]
|
||||
expected_losses = np.array([18.42068074, 46.05170186, 115.12925465])
|
||||
for i, small_value in enumerate(small_values):
|
||||
loss = utils.log_loss(y, pred, small_value)
|
||||
np.testing.assert_allclose(
|
||||
loss, np.array([expected_losses[i], 0]), atol=1e-7)
|
||||
loss = utils.log_loss(y, pred, small_value=small_value)
|
||||
np.testing.assert_allclose(loss, np.array([expected_loss, 0]), atol=1e-7)
|
||||
|
||||
def test_log_loss_from_logits(self):
|
||||
"""Test computing cross-entropy loss from logits."""
|
||||
|
||||
logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]])
|
||||
labels = np.array([0, 3, 1])
|
||||
expected_loss = np.array([1.4401897, 3.4401897, 0.11144278])
|
||||
|
||||
loss = utils.log_loss_from_logits(labels, logits)
|
||||
loss = utils.log_loss(labels, logits, from_logits=True)
|
||||
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('label0', 0,
|
||||
np.array([
|
||||
0.2231435513, 1.2039728043, 0.1053605157, 4.6051701860, 0.0020020027,
|
||||
0.0080321717
|
||||
])), ('label1', 1,
|
||||
np.array([
|
||||
1.6094379124, 0.3566749439, 2.3025850930, 0.0100503359,
|
||||
6.2146080984, 4.8283137373
|
||||
])))
|
||||
def test_log_loss_binary_from_probs(self, label, expected_loss):
|
||||
pred = np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008])
|
||||
y = np.full(pred.shape[0], label)
|
||||
loss = utils.log_loss(y, pred)
|
||||
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('label0', 0, np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])),
|
||||
('label1', 1, np.array([10, 5, 0.6931471825, 0.006715348, 0.000045398])),
|
||||
)
|
||||
def test_log_loss_binary_from_logits(self, label, expected_loss):
|
||||
pred = np.array([-10, -5, 0., 5, 10])
|
||||
y = np.full(pred.shape[0], label)
|
||||
loss = utils.log_loss(y, pred, from_logits=True)
|
||||
np.testing.assert_allclose(expected_loss, loss, rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('binary_mismatch', np.array([0, 1, 2]), np.ones((3,))),
|
||||
('binary_wrong_label', np.array([-1, 1]), np.ones((2,))),
|
||||
('multiclass_wrong_label', np.array([0, 3]), np.ones((2, 3))),
|
||||
)
|
||||
def test_log_loss_wrong_classes(self, labels, pred):
|
||||
self.assertRaises(ValueError, utils.log_loss, labels=labels, pred=pred)
|
||||
|
||||
|
||||
class TestSquaredLoss(parameterized.TestCase):
|
||||
|
||||
def test_squared_loss(self):
|
||||
y_true = np.array([1, 2, 3, 4.])
|
||||
y_pred = np.array([4, 3, 2, 1.])
|
||||
expected_loss = np.array([9, 1, 1, 9.])
|
||||
loss = utils.squared_loss(y_true, y_pred)
|
||||
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue