Add additional tests and checks on the passed loss function.
PiperOrigin-RevId: 532225904
This commit is contained in:
parent
8fdac5f833
commit
0f5acf868e
2 changed files with 94 additions and 1 deletions
|
@ -128,6 +128,16 @@ def compute_gradient_norms(
|
||||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||||
losses = per_example_loss_fn(y_batch, model_outputs)
|
losses = per_example_loss_fn(y_batch, model_outputs)
|
||||||
|
if losses.shape is None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The unreduced (or per-example) loss's shape cannot be `None`"
|
||||||
|
)
|
||||||
|
if len(losses.shape) != 1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
'The unreduced (or per-example) loss needs to have a shape of length '
|
||||||
|
'one, but received an unreduced loss of shape length %s'
|
||||||
|
% len(losses.shape)
|
||||||
|
)
|
||||||
if num_microbatches is not None:
|
if num_microbatches is not None:
|
||||||
losses = tf.reduce_mean(
|
losses = tf.reduce_mean(
|
||||||
lr.add_microbatch_axis(losses, num_microbatches), axis=1
|
lr.add_microbatch_axis(losses, num_microbatches), axis=1
|
||||||
|
@ -239,6 +249,11 @@ def compute_clipped_gradients_and_outputs(
|
||||||
`input_model`, weighted by the loss weights generated by a specific
|
`input_model`, weighted by the loss weights generated by a specific
|
||||||
`compute_clip_weights()` call.
|
`compute_clip_weights()` call.
|
||||||
"""
|
"""
|
||||||
|
if input_model.loss.reduction == 'none':
|
||||||
|
raise NotImplementedError(
|
||||||
|
'Fast gradient clipping does not support '
|
||||||
|
'models with unreduced loss functions.'
|
||||||
|
)
|
||||||
if clipping_loss is None:
|
if clipping_loss is None:
|
||||||
clipping_loss = input_model.compiled_loss
|
clipping_loss = input_model.compiled_loss
|
||||||
gradient_norms = compute_gradient_norms(
|
gradient_norms = compute_gradient_norms(
|
||||||
|
|
|
@ -426,7 +426,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# TODO(wkong): Test sparse input tensors when the GitHub CI environment
|
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
|
||||||
# supports them for embeddings.
|
# supports them for embeddings.
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
x_batch=[
|
x_batch=[
|
||||||
|
@ -541,5 +541,83 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipGradsComputeClippedGradsAndOutputsTest(
    tf.test.TestCase, parameterized.TestCase
):
  """Tests `compute_clipped_gradients_and_outputs()` on different losses.

  Each case compiles the fixture model with a `MeanSquaredError` loss using
  one of the supported reduction modes and checks the norm of the clipped
  gradients against the norm of the unclipped gradients.
  """

  def setUp(self):
    """Builds the model shared by every parameterized case."""
    super().setUp()
    # The generator ignores its first argument and only uses the second one
    # as the number of units of the dense layer.
    dense_generator = lambda a, b: tf.keras.layers.Dense(b)
    self._input_dim = 2
    self._output_dim = 3
    self._model = make_two_layer_sequential_model(
        dense_generator, self._input_dim, self._output_dim
    )

  @parameterized.product(
      batch_size=[1, 2, 10],
      l2_norm_clip=[0.1, 1.0, 10],
      is_eager=[True, False],
      reduction=['auto', 'sum', 'sum_over_batch_size', 'none'],
  )
  def test_clipped_gradients_on_different_losses(
      self, batch_size, l2_norm_clip, is_eager, reduction
  ):
    """Checks the clipped gradient norm for every loss reduction mode.

    Args:
      batch_size: Number of examples in the generated batch.
      l2_norm_clip: Clipping bound passed to the function under test.
      is_eager: Value forwarded to `run_eagerly` when compiling the model.
      reduction: Reduction mode of the `MeanSquaredError` loss. The `'none'`
        mode is expected to be rejected with a `NotImplementedError`.
    """
    loss_fn = tf.keras.losses.MeanSquaredError(reduction=reduction)
    self._model.compile(loss=loss_fn, run_eagerly=is_eager)
    x_batch = tf.reshape(
        tf.range(batch_size * self._input_dim, dtype=tf.float32),
        [batch_size, -1],
    )
    y_batch = tf.reshape(
        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
    )
    registry = layer_registry.make_default_layer_registry()

    # Stop early for efficiency: unreduced losses must be rejected outright.
    if reduction == 'none':
      with self.assertRaises(NotImplementedError):
        clip_grads.compute_clipped_gradients_and_outputs(
            self._model,
            x_batch,
            y_batch,
            l2_norm_clip,
            registry,
        )
      return

    # NOTE: losses from this point are scalar losses.
    with tf.GradientTape() as tape:
      y_pred = self._model(x_batch)
      # Keras losses are called as `loss(y_true, y_pred)`.
      loss_value = loss_fn(y_batch, y_pred)
    true_grads = tape.gradient(loss_value, self._model.trainable_variables)
    clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
        self._model,
        x_batch,
        y_batch,
        l2_norm_clip,
        registry,
    )

    # Computes the L2 norm manually.
    def compute_l2_norm(t):
      sqr_sum_fn = lambda x: tf.reduce_sum(tf.square(x))
      return tf.sqrt(tf.add_n(tf.nest.map_structure(sqr_sum_fn, t)))

    true_norm = compute_l2_norm(true_grads)
    computed_norm = compute_l2_norm(clipped_grads)
    norm_bound = (
        l2_norm_clip * batch_size if reduction == 'sum' else l2_norm_clip
    )
    if true_norm >= norm_bound:
      # All of the per-example gradient norms should be at most the L2 norm
      # clip value. Hence, by the triangle inequality, the gradient norm of the
      # summed loss (averaged loss) should be at most the clip value times
      # the batch size (just the clip value).
      self.assertLessEqual(computed_norm, norm_bound)
    else:
      self.assertAlmostEqual(computed_norm, true_norm)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue