Allow squared loss to take in labels and predictions of the same number of elements but different shapes.
PiperOrigin-RevId: 474059427
This commit is contained in:
parent
ebae6c086e
commit
08364adcb7
2 changed files with 33 additions and 0 deletions
|
@ -79,6 +79,20 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
|
|||
Returns:
|
||||
the squared loss of each sample.
|
||||
"""
|
||||
if y_true.ndim != 1:
|
||||
logging.warning(('Squared loss expects the labels to have shape '
|
||||
'(num_examples, ) but got shape %s. Will use np.squeeze.'),
|
||||
y_true.shape)
|
||||
y_true = np.squeeze(y_true)
|
||||
if y_pred.ndim != 1:
|
||||
logging.warning(('Squared loss expects the predictions to have shape '
|
||||
'(num_examples, ) but got shape %s. Will use np.squeeze.'),
|
||||
y_pred.shape)
|
||||
y_pred = np.squeeze(y_pred)
|
||||
if y_true.shape != y_pred.shape:
|
||||
raise ValueError('Squared loss expects the labels and predictions to have '
|
||||
'shape (num_examples, ), but after np.squeeze, the shapes '
|
||||
'are %s and %s.' % (y_true.shape, y_pred.shape))
|
||||
return (y_true - y_pred)**2
|
||||
|
||||
|
||||
|
|
|
@ -142,6 +142,25 @@ class TestSquaredLoss(parameterized.TestCase):
|
|||
loss = utils.squared_loss(y_true, y_pred)
|
||||
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
|
||||
|
||||
def test_squared_loss_need_squeeze(self):
|
||||
y_true = np.array([1, 2, 3, 4.]).reshape((-1, 1))
|
||||
y_pred = np.array([4, 3, 2, 1.]).reshape((1, -1))
|
||||
expected_loss = np.array([9, 1, 1, 9.])
|
||||
loss = utils.squared_loss(y_true, y_pred)
|
||||
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('wrong shape y_true', np.ones((2, 2)), np.ones((4,))),
|
||||
('wrong shape y_pred', np.ones((4,)), np.ones((2, 2))),
|
||||
)
|
||||
def test_squared_loss_wrong_shape(self, y_true, y_pred):
|
||||
self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred)
|
||||
|
||||
def test_squared_loss_different_num_of_elements(self):
|
||||
y_true = np.array([1, 2, 3, 4.])
|
||||
y_pred = np.array([4, 3, 2])
|
||||
self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred)
|
||||
|
||||
|
||||
class TestMultilabelBCELoss(parameterized.TestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue