Adds a utility function for formating list into string.

PiperOrigin-RevId: 484026229
This commit is contained in:
Shuang Song 2022-10-26 11:32:41 -07:00 committed by A. Unique TensorFlower
parent 7d7b670f5d
commit f7e1e61823
3 changed files with 50 additions and 2 deletions

View file

@ -15,7 +15,10 @@ py_test(
srcs = ["utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":utils"],
deps = [
":utils",
"//third_party/py/parameterized",
],
)
py_test(

View file

@ -15,7 +15,8 @@
import enum
import logging
from typing import Callable, Optional, Union
import numbers
from typing import Callable, Iterable, Optional, Union
import numpy as np
from scipy import special
@ -254,3 +255,9 @@ def get_loss(
else:
loss = loss_function(labels, predictions, sample_weight)
return loss
def format_number_list(input_list: Iterable[numbers.Number],
precision: int = 4) -> str:
"""Formats list of numbers as a string."""
return ', '.join([f'{x:.{precision}f}' for x in input_list])

View file

@ -280,5 +280,43 @@ class TestGetLoss(parameterized.TestCase):
mock_squared_loss.assert_called_once()
@parameterized.parameters(
# integer list, one element
([1], 0, '1'),
([-1], 1, '-1.0'),
# integer list, multiple elements
([1, 2, 3], 0, '1, 2, 3'),
([-1, 2, -3], 1, '-1.0, 2.0, -3.0'),
# float list, one element
([1.1], 0, '1'),
([1.1], 1, '1.1'),
([-1.1], 3, '-1.100'),
# float and integer combined, multiple elements
([0, 1.1, 2.22, 3.333], 0, '0, 1, 2, 3'),
([0, 1.1, -2.22, 3.333], 1, '0.0, 1.1, -2.2, 3.3'),
([0, 1.1, 2.22, 3.333], 3, '0.000, 1.100, 2.220, 3.333'),
# inf and nan
([np.inf, 1, -2.22, -np.inf, np.nan], 1, 'inf, 1.0, -2.2, -inf, nan'),
# empty list
([], 1, ''),
# iterables other than list
((np.inf, 1, 2.2), 0, 'inf, 1, 2'),
(range(-1, 3), 1, '-1.0, 0.0, 1.0, 2.0'))
class TestPrintNumberList(parameterized.TestCase):
def test_format_list(self, input_list, precision, expected_output):
self.assertEqual(
utils.format_number_list(input_list, precision), expected_output)
def test_format_iterator(self, input_list, precision, expected_output):
self.assertEqual(
utils.format_number_list(iter(input_list), precision), expected_output)
def test_format_numpy_array(self, input_list, precision, expected_output):
self.assertEqual(
utils.format_number_list(np.array(input_list), precision),
expected_output)
if __name__ == '__main__':
absltest.main()