From 08364adcb724f0ea60c64ac92381b343476b63e5 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Tue, 13 Sep 2022 10:32:25 -0700 Subject: [PATCH] Allow squared loss to take in labels and predictions of the same number of elements but different shapes. PiperOrigin-RevId: 474059427 --- .../privacy/privacy_tests/utils.py | 14 ++++++++++++++ .../privacy/privacy_tests/utils_test.py | 19 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/tensorflow_privacy/privacy/privacy_tests/utils.py b/tensorflow_privacy/privacy/privacy_tests/utils.py index d448f95..ae97a8d 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils.py @@ -79,6 +79,20 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: Returns: the squared loss of each sample. """ + if y_true.ndim != 1: + logging.warning(('Squared loss expects the labels to have shape ' + '(num_examples, ) but got shape %s. Will use np.squeeze.'), + y_true.shape) + y_true = np.squeeze(y_true) + if y_pred.ndim != 1: + logging.warning(('Squared loss expects the predictions to have shape ' + '(num_examples, ) but got shape %s. Will use np.squeeze.'), + y_pred.shape) + y_pred = np.squeeze(y_pred) + if y_true.shape != y_pred.shape: + raise ValueError('Squared loss expects the labels and predictions to have ' + 'shape (num_examples, ), but after np.squeeze, the shapes ' + 'are %s and %s.' % (y_true.shape, y_pred.shape)) return (y_true - y_pred)**2 diff --git a/tensorflow_privacy/privacy/privacy_tests/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/utils_test.py index df097f6..4c37680 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils_test.py @@ -142,6 +142,25 @@ class TestSquaredLoss(parameterized.TestCase): loss = utils.squared_loss(y_true, y_pred) np.testing.assert_allclose(loss, expected_loss, atol=1e-7) + def test_squared_loss_need_squeeze(self): + y_true = np.array([1, 2, 3, 4.]).reshape((-1, 1)) + y_pred = np.array([4, 3, 2, 1.]).reshape((1, -1)) + expected_loss = np.array([9, 1, 1, 9.]) + loss = utils.squared_loss(y_true, y_pred) + np.testing.assert_allclose(loss, expected_loss, atol=1e-7) + + @parameterized.named_parameters( + ('wrong shape y_true', np.ones((2, 2)), np.ones((4,))), + ('wrong shape y_pred', np.ones((4,)), np.ones((2, 2))), + ) + def test_squared_loss_wrong_shape(self, y_true, y_pred): + self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred) + + def test_squared_loss_different_num_of_elements(self): + y_true = np.array([1, 2, 3, 4.]) + y_pred = np.array([4, 3, 2]) + self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred) + class TestMultilabelBCELoss(parameterized.TestCase):