Allow specifying loss function with string.

PiperOrigin-RevId: 465333272
This commit is contained in:
Shuang Song 2022-08-04 09:30:59 -07:00 committed by A. Unique TensorFlower
parent a8a5206841
commit a9abfbc244
3 changed files with 43 additions and 3 deletions

View file

@ -215,10 +215,11 @@ class AttackInputData:
entropy_test: Optional[np.ndarray] = None
# If loss is not explicitly specified, this function will be used to derive
# loss from logits and labels. It can be a pre-defined `LossFunction`.
# loss from logits and labels. It can be a pre-defined `LossFunction` or its
# string representation, or a callable.
# If a callable is provided, it should take in two argument, the 1st is
# labels, the 2nd is logits or probs.
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str,
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
# Whether `loss_function` will be called with logits or probs. If not set
# (None), will decide by availablity of logits and probs and logits is

View file

@ -146,7 +146,7 @@ def string_to_loss_function(string: str):
def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray],
np.ndarray], LossFunction],
np.ndarray], LossFunction, str],
loss_function_using_logits: Optional[bool],
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
"""Calculates (if needed) losses.
@ -176,6 +176,9 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
raise ValueError('We need probs to compute loss, but it is set to None.')
predictions = logits if loss_function_using_logits else probs
if isinstance(loss_function, str):
loss_function = string_to_loss_function(loss_function)
if loss_function == LossFunction.CROSS_ENTROPY:
if multilabel_data:
loss = multilabel_bce_loss(labels, predictions,

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
@ -198,5 +200,39 @@ class TestMultilabelBCELoss(parameterized.TestCase):
pred, from_logits)
class TestGetLoss(parameterized.TestCase):
@parameterized.named_parameters(
('xe', utils.LossFunction.CROSS_ENTROPY, False),
('xe str', 'cross_entropy', False),
('xe multi', utils.LossFunction.CROSS_ENTROPY, True),
('xe multi str', 'cross_entropy', True),
('sq', utils.LossFunction.SQUARED, False),
('sq str', 'squared', False),
)
@mock.patch.object(utils, 'squared_loss')
@mock.patch.object(utils, 'multilabel_bce_loss')
@mock.patch.object(utils, 'log_loss')
def test_get_loss_call_loss_function(self, loss_function, multilabel_data,
mock_log_loss, mock_multilabel_bce_loss,
mock_squared_loss):
"""Test if get_loss calls the correct loss_function."""
utils.get_loss(
loss=None,
labels=np.array([0]),
logits=np.array([[0.1, -0.1]]),
probs=None,
loss_function=loss_function,
loss_function_using_logits=True,
multilabel_data=multilabel_data)
if loss_function in ['cross_entropy', utils.LossFunction.CROSS_ENTROPY]:
if not multilabel_data:
mock_log_loss.assert_called_once()
else:
mock_multilabel_bce_loss.assert_called_once()
else:
mock_squared_loss.assert_called_once()
if __name__ == '__main__':
absltest.main()