Allow specifying loss function with string.
PiperOrigin-RevId: 465333272
This commit is contained in:
parent
a8a5206841
commit
a9abfbc244
3 changed files with 43 additions and 3 deletions
|
@ -215,10 +215,11 @@ class AttackInputData:
|
||||||
entropy_test: Optional[np.ndarray] = None
|
entropy_test: Optional[np.ndarray] = None
|
||||||
|
|
||||||
# If loss is not explicitly specified, this function will be used to derive
|
# If loss is not explicitly specified, this function will be used to derive
|
||||||
# loss from logits and labels. It can be a pre-defined `LossFunction`.
|
# loss from logits and labels. It can be a pre-defined `LossFunction` or its
|
||||||
|
# string representation, or a callable.
|
||||||
# If a callable is provided, it should take in two argument, the 1st is
|
# If a callable is provided, it should take in two argument, the 1st is
|
||||||
# labels, the 2nd is logits or probs.
|
# labels, the 2nd is logits or probs.
|
||||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str,
|
||||||
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
|
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
|
||||||
# Whether `loss_function` will be called with logits or probs. If not set
|
# Whether `loss_function` will be called with logits or probs. If not set
|
||||||
# (None), will decide by availablity of logits and probs and logits is
|
# (None), will decide by availablity of logits and probs and logits is
|
||||||
|
|
|
@ -146,7 +146,7 @@ def string_to_loss_function(string: str):
|
||||||
def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||||
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
||||||
loss_function: Union[Callable[[np.ndarray, np.ndarray],
|
loss_function: Union[Callable[[np.ndarray, np.ndarray],
|
||||||
np.ndarray], LossFunction],
|
np.ndarray], LossFunction, str],
|
||||||
loss_function_using_logits: Optional[bool],
|
loss_function_using_logits: Optional[bool],
|
||||||
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
|
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
|
||||||
"""Calculates (if needed) losses.
|
"""Calculates (if needed) losses.
|
||||||
|
@ -176,6 +176,9 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||||
raise ValueError('We need probs to compute loss, but it is set to None.')
|
raise ValueError('We need probs to compute loss, but it is set to None.')
|
||||||
|
|
||||||
predictions = logits if loss_function_using_logits else probs
|
predictions = logits if loss_function_using_logits else probs
|
||||||
|
|
||||||
|
if isinstance(loss_function, str):
|
||||||
|
loss_function = string_to_loss_function(loss_function)
|
||||||
if loss_function == LossFunction.CROSS_ENTROPY:
|
if loss_function == LossFunction.CROSS_ENTROPY:
|
||||||
if multilabel_data:
|
if multilabel_data:
|
||||||
loss = multilabel_bce_loss(labels, predictions,
|
loss = multilabel_bce_loss(labels, predictions,
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -198,5 +200,39 @@ class TestMultilabelBCELoss(parameterized.TestCase):
|
||||||
pred, from_logits)
|
pred, from_logits)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLoss(parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('xe', utils.LossFunction.CROSS_ENTROPY, False),
|
||||||
|
('xe str', 'cross_entropy', False),
|
||||||
|
('xe multi', utils.LossFunction.CROSS_ENTROPY, True),
|
||||||
|
('xe multi str', 'cross_entropy', True),
|
||||||
|
('sq', utils.LossFunction.SQUARED, False),
|
||||||
|
('sq str', 'squared', False),
|
||||||
|
)
|
||||||
|
@mock.patch.object(utils, 'squared_loss')
|
||||||
|
@mock.patch.object(utils, 'multilabel_bce_loss')
|
||||||
|
@mock.patch.object(utils, 'log_loss')
|
||||||
|
def test_get_loss_call_loss_function(self, loss_function, multilabel_data,
|
||||||
|
mock_log_loss, mock_multilabel_bce_loss,
|
||||||
|
mock_squared_loss):
|
||||||
|
"""Test if get_loss calls the correct loss_function."""
|
||||||
|
utils.get_loss(
|
||||||
|
loss=None,
|
||||||
|
labels=np.array([0]),
|
||||||
|
logits=np.array([[0.1, -0.1]]),
|
||||||
|
probs=None,
|
||||||
|
loss_function=loss_function,
|
||||||
|
loss_function_using_logits=True,
|
||||||
|
multilabel_data=multilabel_data)
|
||||||
|
if loss_function in ['cross_entropy', utils.LossFunction.CROSS_ENTROPY]:
|
||||||
|
if not multilabel_data:
|
||||||
|
mock_log_loss.assert_called_once()
|
||||||
|
else:
|
||||||
|
mock_multilabel_bce_loss.assert_called_once()
|
||||||
|
else:
|
||||||
|
mock_squared_loss.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue