diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index c122443..dc558c1 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -215,10 +215,11 @@ class AttackInputData: entropy_test: Optional[np.ndarray] = None # If loss is not explicitly specified, this function will be used to derive - # loss from logits and labels. It can be a pre-defined `LossFunction`. + # loss from logits and labels. It can be a pre-defined `LossFunction` or its + # string representation, or a callable. # If a callable is provided, it should take in two argument, the 1st is # labels, the 2nd is logits or probs. - loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], + loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str, utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY # Whether `loss_function` will be called with logits or probs. If not set # (None), will decide by availablity of logits and probs and logits is diff --git a/tensorflow_privacy/privacy/privacy_tests/utils.py b/tensorflow_privacy/privacy/privacy_tests/utils.py index b09384a..d448f95 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils.py @@ -146,7 +146,7 @@ def string_to_loss_function(string: str): def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], logits: Optional[np.ndarray], probs: Optional[np.ndarray], loss_function: Union[Callable[[np.ndarray, np.ndarray], - np.ndarray], LossFunction], + np.ndarray], LossFunction, str], loss_function_using_logits: Optional[bool], multilabel_data: Optional[bool]) -> Optional[np.ndarray]: """Calculates (if needed) losses. @@ -176,6 +176,9 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], raise ValueError('We need probs to compute loss, but it is set to None.') predictions = logits if loss_function_using_logits else probs + + if isinstance(loss_function, str): + loss_function = string_to_loss_function(loss_function) if loss_function == LossFunction.CROSS_ENTROPY: if multilabel_data: loss = multilabel_bce_loss(labels, predictions, diff --git a/tensorflow_privacy/privacy/privacy_tests/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/utils_test.py index b65a3b9..df097f6 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + from absl.testing import absltest from absl.testing import parameterized import numpy as np @@ -198,5 +200,39 @@ class TestMultilabelBCELoss(parameterized.TestCase): pred, from_logits) +class TestGetLoss(parameterized.TestCase): + + @parameterized.named_parameters( + ('xe', utils.LossFunction.CROSS_ENTROPY, False), + ('xe str', 'cross_entropy', False), + ('xe multi', utils.LossFunction.CROSS_ENTROPY, True), + ('xe multi str', 'cross_entropy', True), + ('sq', utils.LossFunction.SQUARED, False), + ('sq str', 'squared', False), + ) + @mock.patch.object(utils, 'squared_loss') + @mock.patch.object(utils, 'multilabel_bce_loss') + @mock.patch.object(utils, 'log_loss') + def test_get_loss_call_loss_function(self, loss_function, multilabel_data, + mock_log_loss, mock_multilabel_bce_loss, + mock_squared_loss): + """Test if get_loss calls the correct loss_function.""" + utils.get_loss( + loss=None, + labels=np.array([0]), + logits=np.array([[0.1, -0.1]]), + probs=None, + loss_function=loss_function, + loss_function_using_logits=True, + multilabel_data=multilabel_data) + if loss_function in ['cross_entropy', utils.LossFunction.CROSS_ENTROPY]: + if not multilabel_data: + mock_log_loss.assert_called_once() + else: + mock_multilabel_bce_loss.assert_called_once() + else: + mock_squared_loss.assert_called_once() + + if __name__ == '__main__': absltest.main()