Allow specifying loss function with string.

PiperOrigin-RevId: 465333272
This commit is contained in:
Shuang Song 2022-08-04 09:30:59 -07:00 committed by A. Unique TensorFlower
parent a8a5206841
commit a9abfbc244
3 changed files with 43 additions and 3 deletions

View file

@ -215,10 +215,11 @@ class AttackInputData:
entropy_test: Optional[np.ndarray] = None entropy_test: Optional[np.ndarray] = None
# If loss is not explicitly specified, this function will be used to derive # If loss is not explicitly specified, this function will be used to derive
# loss from logits and labels. It can be a pre-defined `LossFunction`. # loss from logits and labels. It can be a pre-defined `LossFunction` or its
# string representation, or a callable.
# If a callable is provided, it should take in two argument, the 1st is # If a callable is provided, it should take in two argument, the 1st is
# labels, the 2nd is logits or probs. # labels, the 2nd is logits or probs.
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str,
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
# Whether `loss_function` will be called with logits or probs. If not set # Whether `loss_function` will be called with logits or probs. If not set
# (None), will decide by availablity of logits and probs and logits is # (None), will decide by availablity of logits and probs and logits is

View file

@ -146,7 +146,7 @@ def string_to_loss_function(string: str):
def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray], logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray], loss_function: Union[Callable[[np.ndarray, np.ndarray],
np.ndarray], LossFunction], np.ndarray], LossFunction, str],
loss_function_using_logits: Optional[bool], loss_function_using_logits: Optional[bool],
multilabel_data: Optional[bool]) -> Optional[np.ndarray]: multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
"""Calculates (if needed) losses. """Calculates (if needed) losses.
@ -176,6 +176,9 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
raise ValueError('We need probs to compute loss, but it is set to None.') raise ValueError('We need probs to compute loss, but it is set to None.')
predictions = logits if loss_function_using_logits else probs predictions = logits if loss_function_using_logits else probs
if isinstance(loss_function, str):
loss_function = string_to_loss_function(loss_function)
if loss_function == LossFunction.CROSS_ENTROPY: if loss_function == LossFunction.CROSS_ENTROPY:
if multilabel_data: if multilabel_data:
loss = multilabel_bce_loss(labels, predictions, loss = multilabel_bce_loss(labels, predictions,

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest import mock
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
@ -198,5 +200,39 @@ class TestMultilabelBCELoss(parameterized.TestCase):
pred, from_logits) pred, from_logits)
class TestGetLoss(parameterized.TestCase):
@parameterized.named_parameters(
('xe', utils.LossFunction.CROSS_ENTROPY, False),
('xe str', 'cross_entropy', False),
('xe multi', utils.LossFunction.CROSS_ENTROPY, True),
('xe multi str', 'cross_entropy', True),
('sq', utils.LossFunction.SQUARED, False),
('sq str', 'squared', False),
)
@mock.patch.object(utils, 'squared_loss')
@mock.patch.object(utils, 'multilabel_bce_loss')
@mock.patch.object(utils, 'log_loss')
def test_get_loss_call_loss_function(self, loss_function, multilabel_data,
mock_log_loss, mock_multilabel_bce_loss,
mock_squared_loss):
"""Test if get_loss calls the correct loss_function."""
utils.get_loss(
loss=None,
labels=np.array([0]),
logits=np.array([[0.1, -0.1]]),
probs=None,
loss_function=loss_function,
loss_function_using_logits=True,
multilabel_data=multilabel_data)
if loss_function in ['cross_entropy', utils.LossFunction.CROSS_ENTROPY]:
if not multilabel_data:
mock_log_loss.assert_called_once()
else:
mock_multilabel_bce_loss.assert_called_once()
else:
mock_squared_loss.assert_called_once()
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()