forked from 626_privacy/tensorflow_privacy
Adds a utility function for formating list into string.
PiperOrigin-RevId: 484026229
This commit is contained in:
parent
7d7b670f5d
commit
f7e1e61823
3 changed files with 50 additions and 2 deletions
|
@ -15,7 +15,10 @@ py_test(
|
|||
srcs = ["utils_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":utils"],
|
||||
deps = [
|
||||
":utils",
|
||||
"//third_party/py/parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
|
||||
import enum
|
||||
import logging
|
||||
from typing import Callable, Optional, Union
|
||||
import numbers
|
||||
from typing import Callable, Iterable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from scipy import special
|
||||
|
@ -254,3 +255,9 @@ def get_loss(
|
|||
else:
|
||||
loss = loss_function(labels, predictions, sample_weight)
|
||||
return loss
|
||||
|
||||
|
||||
def format_number_list(input_list: Iterable[numbers.Number],
|
||||
precision: int = 4) -> str:
|
||||
"""Formats list of numbers as a string."""
|
||||
return ', '.join([f'{x:.{precision}f}' for x in input_list])
|
||||
|
|
|
@ -280,5 +280,43 @@ class TestGetLoss(parameterized.TestCase):
|
|||
mock_squared_loss.assert_called_once()
|
||||
|
||||
|
||||
@parameterized.parameters(
|
||||
# integer list, one element
|
||||
([1], 0, '1'),
|
||||
([-1], 1, '-1.0'),
|
||||
# integer list, multiple elements
|
||||
([1, 2, 3], 0, '1, 2, 3'),
|
||||
([-1, 2, -3], 1, '-1.0, 2.0, -3.0'),
|
||||
# float list, one element
|
||||
([1.1], 0, '1'),
|
||||
([1.1], 1, '1.1'),
|
||||
([-1.1], 3, '-1.100'),
|
||||
# float and integer combined, multiple elements
|
||||
([0, 1.1, 2.22, 3.333], 0, '0, 1, 2, 3'),
|
||||
([0, 1.1, -2.22, 3.333], 1, '0.0, 1.1, -2.2, 3.3'),
|
||||
([0, 1.1, 2.22, 3.333], 3, '0.000, 1.100, 2.220, 3.333'),
|
||||
# inf and nan
|
||||
([np.inf, 1, -2.22, -np.inf, np.nan], 1, 'inf, 1.0, -2.2, -inf, nan'),
|
||||
# empty list
|
||||
([], 1, ''),
|
||||
# iterables other than list
|
||||
((np.inf, 1, 2.2), 0, 'inf, 1, 2'),
|
||||
(range(-1, 3), 1, '-1.0, 0.0, 1.0, 2.0'))
|
||||
class TestPrintNumberList(parameterized.TestCase):
|
||||
|
||||
def test_format_list(self, input_list, precision, expected_output):
|
||||
self.assertEqual(
|
||||
utils.format_number_list(input_list, precision), expected_output)
|
||||
|
||||
def test_format_iterator(self, input_list, precision, expected_output):
|
||||
self.assertEqual(
|
||||
utils.format_number_list(iter(input_list), precision), expected_output)
|
||||
|
||||
def test_format_numpy_array(self, input_list, precision, expected_output):
|
||||
self.assertEqual(
|
||||
utils.format_number_list(np.array(input_list), precision),
|
||||
expected_output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue