diff --git a/tensorflow_privacy/privacy/privacy_tests/BUILD b/tensorflow_privacy/privacy/privacy_tests/BUILD index 2a0eefc..ad133d3 100644 --- a/tensorflow_privacy/privacy/privacy_tests/BUILD +++ b/tensorflow_privacy/privacy/privacy_tests/BUILD @@ -15,7 +15,10 @@ py_test( srcs = ["utils_test.py"], python_version = "PY3", srcs_version = "PY3", - deps = [":utils"], + deps = [ + ":utils", + "//third_party/py/parameterized", + ], ) py_test( diff --git a/tensorflow_privacy/privacy/privacy_tests/utils.py b/tensorflow_privacy/privacy/privacy_tests/utils.py index e8f31aa..607500a 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils.py @@ -15,7 +15,8 @@ import enum import logging -from typing import Callable, Optional, Union +import numbers +from typing import Callable, Iterable, Optional, Union import numpy as np from scipy import special @@ -254,3 +255,9 @@ def get_loss( else: loss = loss_function(labels, predictions, sample_weight) return loss + + +def format_number_list(input_list: Iterable[numbers.Number], + precision: int = 4) -> str: + """Formats list of numbers as a string.""" + return ', '.join([f'{x:.{precision}f}' for x in input_list]) diff --git a/tensorflow_privacy/privacy/privacy_tests/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/utils_test.py index cb592a8..edc58ac 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils_test.py @@ -280,5 +280,43 @@ class TestGetLoss(parameterized.TestCase): mock_squared_loss.assert_called_once() +@parameterized.parameters( + # integer list, one element + ([1], 0, '1'), + ([-1], 1, '-1.0'), + # integer list, multiple elements + ([1, 2, 3], 0, '1, 2, 3'), + ([-1, 2, -3], 1, '-1.0, 2.0, -3.0'), + # float list, one element + ([1.1], 0, '1'), + ([1.1], 1, '1.1'), + ([-1.1], 3, '-1.100'), + # float and integer combined, multiple elements + ([0, 1.1, 2.22, 3.333], 0, '0, 1, 2, 3'), + ([0, 1.1, -2.22, 3.333], 1, '0.0, 1.1, -2.2, 3.3'), + ([0, 1.1, 2.22, 3.333], 3, '0.000, 1.100, 2.220, 3.333'), + # inf and nan + ([np.inf, 1, -2.22, -np.inf, np.nan], 1, 'inf, 1.0, -2.2, -inf, nan'), + # empty list + ([], 1, ''), + # iterables other than list + ((np.inf, 1, 2.2), 0, 'inf, 1, 2'), + (range(-1, 3), 1, '-1.0, 0.0, 1.0, 2.0')) +class TestPrintNumberList(parameterized.TestCase): + + def test_format_list(self, input_list, precision, expected_output): + self.assertEqual( + utils.format_number_list(input_list, precision), expected_output) + + def test_format_iterator(self, input_list, precision, expected_output): + self.assertEqual( + utils.format_number_list(iter(input_list), precision), expected_output) + + def test_format_numpy_array(self, input_list, precision, expected_output): + self.assertEqual( + utils.format_number_list(np.array(input_list), precision), + expected_output) + + if __name__ == '__main__': absltest.main()