Allow squared loss to take in labels and predictions of the same number of elements but different shapes.

PiperOrigin-RevId: 474059427
This commit is contained in:
Shuang Song 2022-09-13 10:32:25 -07:00 committed by A. Unique TensorFlower
parent ebae6c086e
commit 08364adcb7
2 changed files with 33 additions and 0 deletions

View file

@ -79,6 +79,20 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
Returns:
the squared loss of each sample.
"""
if y_true.ndim != 1:
logging.warning(('Squared loss expects the labels to have shape '
'(num_examples, ) but got shape %s. Will use np.squeeze.'),
y_true.shape)
y_true = np.squeeze(y_true)
if y_pred.ndim != 1:
logging.warning(('Squared loss expects the predictions to have shape '
'(num_examples, ) but got shape %s. Will use np.squeeze.'),
y_pred.shape)
y_pred = np.squeeze(y_pred)
if y_true.shape != y_pred.shape:
raise ValueError('Squared loss expects the labels and predictions to have '
'shape (num_examples, ), but after np.squeeze, the shapes '
'are %s and %s.' % (y_true.shape, y_pred.shape))
return (y_true - y_pred)**2

View file

@ -142,6 +142,25 @@ class TestSquaredLoss(parameterized.TestCase):
loss = utils.squared_loss(y_true, y_pred)
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
def test_squared_loss_need_squeeze(self):
y_true = np.array([1, 2, 3, 4.]).reshape((-1, 1))
y_pred = np.array([4, 3, 2, 1.]).reshape((1, -1))
expected_loss = np.array([9, 1, 1, 9.])
loss = utils.squared_loss(y_true, y_pred)
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
@parameterized.named_parameters(
('wrong shape y_true', np.ones((2, 2)), np.ones((4,))),
('wrong shape y_pred', np.ones((4,)), np.ones((2, 2))),
)
def test_squared_loss_wrong_shape(self, y_true, y_pred):
self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred)
def test_squared_loss_different_num_of_elements(self):
y_true = np.array([1, 2, 3, 4.])
y_pred = np.array([4, 3, 2])
self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred)
class TestMultilabelBCELoss(parameterized.TestCase):