Adds a utility function for formating list into string.

PiperOrigin-RevId: 484026229
This commit is contained in:
Shuang Song 2022-10-26 11:32:41 -07:00 committed by A. Unique TensorFlower
parent 7d7b670f5d
commit f7e1e61823
3 changed files with 50 additions and 2 deletions

View file

@ -15,7 +15,10 @@ py_test(
srcs = ["utils_test.py"], srcs = ["utils_test.py"],
python_version = "PY3", python_version = "PY3",
srcs_version = "PY3", srcs_version = "PY3",
deps = [":utils"], deps = [
":utils",
"//third_party/py/parameterized",
],
) )
py_test( py_test(

View file

@ -15,7 +15,8 @@
import enum import enum
import logging import logging
from typing import Callable, Optional, Union import numbers
from typing import Callable, Iterable, Optional, Union
import numpy as np import numpy as np
from scipy import special from scipy import special
@ -254,3 +255,9 @@ def get_loss(
else: else:
loss = loss_function(labels, predictions, sample_weight) loss = loss_function(labels, predictions, sample_weight)
return loss return loss
def format_number_list(input_list: Iterable[numbers.Number],
precision: int = 4) -> str:
"""Formats list of numbers as a string."""
return ', '.join([f'{x:.{precision}f}' for x in input_list])

View file

@ -280,5 +280,43 @@ class TestGetLoss(parameterized.TestCase):
mock_squared_loss.assert_called_once() mock_squared_loss.assert_called_once()
@parameterized.parameters(
# integer list, one element
([1], 0, '1'),
([-1], 1, '-1.0'),
# integer list, multiple elements
([1, 2, 3], 0, '1, 2, 3'),
([-1, 2, -3], 1, '-1.0, 2.0, -3.0'),
# float list, one element
([1.1], 0, '1'),
([1.1], 1, '1.1'),
([-1.1], 3, '-1.100'),
# float and integer combined, multiple elements
([0, 1.1, 2.22, 3.333], 0, '0, 1, 2, 3'),
([0, 1.1, -2.22, 3.333], 1, '0.0, 1.1, -2.2, 3.3'),
([0, 1.1, 2.22, 3.333], 3, '0.000, 1.100, 2.220, 3.333'),
# inf and nan
([np.inf, 1, -2.22, -np.inf, np.nan], 1, 'inf, 1.0, -2.2, -inf, nan'),
# empty list
([], 1, ''),
# iterables other than list
((np.inf, 1, 2.2), 0, 'inf, 1, 2'),
(range(-1, 3), 1, '-1.0, 0.0, 1.0, 2.0'))
class TestPrintNumberList(parameterized.TestCase):
def test_format_list(self, input_list, precision, expected_output):
self.assertEqual(
utils.format_number_list(input_list, precision), expected_output)
def test_format_iterator(self, input_list, precision, expected_output):
self.assertEqual(
utils.format_number_list(iter(input_list), precision), expected_output)
def test_format_numpy_array(self, input_list, precision, expected_output):
self.assertEqual(
utils.format_number_list(np.array(input_list), precision),
expected_output)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()