Allow specifying loss function with string.
PiperOrigin-RevId: 465333272
This commit is contained in:
parent
a8a5206841
commit
a9abfbc244
3 changed files with 43 additions and 3 deletions
|
@ -215,10 +215,11 @@ class AttackInputData:
|
|||
entropy_test: Optional[np.ndarray] = None
|
||||
|
||||
# If loss is not explicitly specified, this function will be used to derive
|
||||
# loss from logits and labels. It can be a pre-defined `LossFunction`.
|
||||
# loss from logits and labels. It can be a pre-defined `LossFunction` or its
|
||||
# string representation, or a callable.
|
||||
# If a callable is provided, it should take in two argument, the 1st is
|
||||
# labels, the 2nd is logits or probs.
|
||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str,
|
||||
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
|
||||
# Whether `loss_function` will be called with logits or probs. If not set
|
||||
# (None), will decide by availablity of logits and probs and logits is
|
||||
|
|
|
@ -146,7 +146,7 @@ def string_to_loss_function(string: str):
|
|||
def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
||||
loss_function: Union[Callable[[np.ndarray, np.ndarray],
|
||||
np.ndarray], LossFunction],
|
||||
np.ndarray], LossFunction, str],
|
||||
loss_function_using_logits: Optional[bool],
|
||||
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
|
||||
"""Calculates (if needed) losses.
|
||||
|
@ -176,6 +176,9 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
|||
raise ValueError('We need probs to compute loss, but it is set to None.')
|
||||
|
||||
predictions = logits if loss_function_using_logits else probs
|
||||
|
||||
if isinstance(loss_function, str):
|
||||
loss_function = string_to_loss_function(loss_function)
|
||||
if loss_function == LossFunction.CROSS_ENTROPY:
|
||||
if multilabel_data:
|
||||
loss = multilabel_bce_loss(labels, predictions,
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
@ -198,5 +200,39 @@ class TestMultilabelBCELoss(parameterized.TestCase):
|
|||
pred, from_logits)
|
||||
|
||||
|
||||
class TestGetLoss(parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('xe', utils.LossFunction.CROSS_ENTROPY, False),
|
||||
('xe str', 'cross_entropy', False),
|
||||
('xe multi', utils.LossFunction.CROSS_ENTROPY, True),
|
||||
('xe multi str', 'cross_entropy', True),
|
||||
('sq', utils.LossFunction.SQUARED, False),
|
||||
('sq str', 'squared', False),
|
||||
)
|
||||
@mock.patch.object(utils, 'squared_loss')
|
||||
@mock.patch.object(utils, 'multilabel_bce_loss')
|
||||
@mock.patch.object(utils, 'log_loss')
|
||||
def test_get_loss_call_loss_function(self, loss_function, multilabel_data,
|
||||
mock_log_loss, mock_multilabel_bce_loss,
|
||||
mock_squared_loss):
|
||||
"""Test if get_loss calls the correct loss_function."""
|
||||
utils.get_loss(
|
||||
loss=None,
|
||||
labels=np.array([0]),
|
||||
logits=np.array([[0.1, -0.1]]),
|
||||
probs=None,
|
||||
loss_function=loss_function,
|
||||
loss_function_using_logits=True,
|
||||
multilabel_data=multilabel_data)
|
||||
if loss_function in ['cross_entropy', utils.LossFunction.CROSS_ENTROPY]:
|
||||
if not multilabel_data:
|
||||
mock_log_loss.assert_called_once()
|
||||
else:
|
||||
mock_multilabel_bce_loss.assert_called_once()
|
||||
else:
|
||||
mock_squared_loss.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue