Adds a utility function for formating list into string.
PiperOrigin-RevId: 484026229
This commit is contained in:
parent
7d7b670f5d
commit
f7e1e61823
3 changed files with 50 additions and 2 deletions
|
@ -15,7 +15,10 @@ py_test(
|
||||||
srcs = ["utils_test.py"],
|
srcs = ["utils_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [":utils"],
|
deps = [
|
||||||
|
":utils",
|
||||||
|
"//third_party/py/parameterized",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable, Optional, Union
|
import numbers
|
||||||
|
from typing import Callable, Iterable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import special
|
from scipy import special
|
||||||
|
@ -254,3 +255,9 @@ def get_loss(
|
||||||
else:
|
else:
|
||||||
loss = loss_function(labels, predictions, sample_weight)
|
loss = loss_function(labels, predictions, sample_weight)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def format_number_list(input_list: Iterable[numbers.Number],
|
||||||
|
precision: int = 4) -> str:
|
||||||
|
"""Formats list of numbers as a string."""
|
||||||
|
return ', '.join([f'{x:.{precision}f}' for x in input_list])
|
||||||
|
|
|
@ -280,5 +280,43 @@ class TestGetLoss(parameterized.TestCase):
|
||||||
mock_squared_loss.assert_called_once()
|
mock_squared_loss.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
# integer list, one element
|
||||||
|
([1], 0, '1'),
|
||||||
|
([-1], 1, '-1.0'),
|
||||||
|
# integer list, multiple elements
|
||||||
|
([1, 2, 3], 0, '1, 2, 3'),
|
||||||
|
([-1, 2, -3], 1, '-1.0, 2.0, -3.0'),
|
||||||
|
# float list, one element
|
||||||
|
([1.1], 0, '1'),
|
||||||
|
([1.1], 1, '1.1'),
|
||||||
|
([-1.1], 3, '-1.100'),
|
||||||
|
# float and integer combined, multiple elements
|
||||||
|
([0, 1.1, 2.22, 3.333], 0, '0, 1, 2, 3'),
|
||||||
|
([0, 1.1, -2.22, 3.333], 1, '0.0, 1.1, -2.2, 3.3'),
|
||||||
|
([0, 1.1, 2.22, 3.333], 3, '0.000, 1.100, 2.220, 3.333'),
|
||||||
|
# inf and nan
|
||||||
|
([np.inf, 1, -2.22, -np.inf, np.nan], 1, 'inf, 1.0, -2.2, -inf, nan'),
|
||||||
|
# empty list
|
||||||
|
([], 1, ''),
|
||||||
|
# iterables other than list
|
||||||
|
((np.inf, 1, 2.2), 0, 'inf, 1, 2'),
|
||||||
|
(range(-1, 3), 1, '-1.0, 0.0, 1.0, 2.0'))
|
||||||
|
class TestPrintNumberList(parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_format_list(self, input_list, precision, expected_output):
|
||||||
|
self.assertEqual(
|
||||||
|
utils.format_number_list(input_list, precision), expected_output)
|
||||||
|
|
||||||
|
def test_format_iterator(self, input_list, precision, expected_output):
|
||||||
|
self.assertEqual(
|
||||||
|
utils.format_number_list(iter(input_list), precision), expected_output)
|
||||||
|
|
||||||
|
def test_format_numpy_array(self, input_list, precision, expected_output):
|
||||||
|
self.assertEqual(
|
||||||
|
utils.format_number_list(np.array(input_list), precision),
|
||||||
|
expected_output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue