Add additional tests and checks on the passed loss function.

PiperOrigin-RevId: 532225904
A. Unique TensorFlower 2023-05-15 14:26:41 -07:00
parent 8fdac5f833
commit 0f5acf868e
2 changed files with 94 additions and 1 deletion

@@ -128,6 +128,16 @@ def compute_gradient_norms(
    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
    per_example_loss_fn = input_model.loss.from_config(loss_config)
    losses = per_example_loss_fn(y_batch, model_outputs)
    if losses.shape is None:
      raise NotImplementedError(
          "The unreduced (or per-example) loss's shape cannot be `None`"
      )
    if len(losses.shape) != 1:
      raise NotImplementedError(
          'The unreduced (or per-example) loss needs to have a shape of length '
          'one, but received an unreduced loss of shape length %s'
          % len(losses.shape)
      )
    if num_microbatches is not None:
      losses = tf.reduce_mean(
          lr.add_microbatch_axis(losses, num_microbatches), axis=1
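The checks added above rely on the convention that a Keras loss rebuilt with `Reduction.NONE` returns one loss value per example, i.e. a tensor whose shape has length one. A minimal sketch of that assumption (illustrative only, not part of the commit):

# Illustrative sketch (not part of the commit): a built-in Keras loss with
# `Reduction.NONE` yields a rank-1 tensor of per-example losses, which is
# exactly what the new shape checks enforce.
import tensorflow as tf

loss_config = tf.keras.losses.MeanSquaredError().get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = tf.keras.losses.MeanSquaredError.from_config(loss_config)

y_true = tf.constant([[1.0], [2.0], [3.0]])
y_pred = tf.constant([[1.5], [2.0], [2.5]])
losses = per_example_loss_fn(y_true, y_pred)
print(losses.shape)  # (3,) -> len(losses.shape) == 1, so the checks pass.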
@@ -239,6 +249,11 @@ def compute_clipped_gradients_and_outputs(
    `input_model`, weighted by the loss weights generated by a specific
    `compute_clip_weights()` call.
  """
  if input_model.loss.reduction == 'none':
    raise NotImplementedError(
        'Fast gradient clipping does not support '
        'models with unreduced loss functions.'
    )
  if clipping_loss is None:
    clipping_loss = input_model.compiled_loss
  gradient_norms = compute_gradient_norms(
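The new guard runs before any gradient work: a model compiled with an unreduced loss is rejected up front. A rough usage sketch (illustrative only, not part of the commit; the `tensorflow_privacy.privacy.fast_gradient_clipping` import paths are assumed):

# Illustrative sketch (not part of the commit); module paths are assumed.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(loss=tf.keras.losses.MeanSquaredError(reduction='none'))

x_batch = tf.ones([4, 2])
y_batch = tf.ones([4, 1])
try:
  clip_grads.compute_clipped_gradients_and_outputs(
      model,
      x_batch,
      y_batch,
      1.0,  # l2_norm_clip
      layer_registry.make_default_layer_registry(),
  )
except NotImplementedError as err:
  print(err)  # 'Fast gradient clipping does not support ...'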

@@ -426,7 +426,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
  # TODO(wkong): Test sparse input tensors when the GitHub CI environment
  # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
  # supports them for embeddings.
  @parameterized.product(
      x_batch=[
@@ -541,5 +541,83 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
    self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)


class ClipGradsComputeClippedGradsAndOutputsTest(
    tf.test.TestCase, parameterized.TestCase
):

  def setUp(self):
    super().setUp()
    dense_generator = lambda a, b: tf.keras.layers.Dense(b)
    self._input_dim = 2
    self._output_dim = 3
    self._model = make_two_layer_sequential_model(
        dense_generator, self._input_dim, self._output_dim
    )

  @parameterized.product(
      batch_size=[1, 2, 10],
      l2_norm_clip=[0.1, 1.0, 10],
      is_eager=[True, False],
      reduction=['auto', 'sum', 'sum_over_batch_size', 'none'],
  )
  def test_clipped_gradients_on_different_losses(
      self, batch_size, l2_norm_clip, is_eager, reduction
  ):
    loss_fn = tf.keras.losses.MeanSquaredError(reduction=reduction)
    self._model.compile(loss=loss_fn, run_eagerly=is_eager)
    x_batch = tf.reshape(
        tf.range(batch_size * self._input_dim, dtype=tf.float32),
        [batch_size, -1],
    )
    y_batch = tf.reshape(
        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
    )
    # Stop early for efficiency.
    if reduction == 'none':
      self.assertRaises(
          NotImplementedError,
          # function tested
          clip_grads.compute_clipped_gradients_and_outputs,
          # function args
          self._model,
          x_batch,
          y_batch,
          l2_norm_clip,
          layer_registry.make_default_layer_registry(),
      )
      return
    # NOTE: losses from this point are scalar losses.
    with tf.GradientTape() as tape:
      y_pred = self._model(x_batch)
      # Keras losses take (y_true, y_pred).
      loss_value = loss_fn(y_batch, y_pred)
    true_grads = tape.gradient(loss_value, self._model.trainable_variables)
    clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
        self._model,
        x_batch,
        y_batch,
        l2_norm_clip,
        layer_registry.make_default_layer_registry(),
    )

    # Computes the L2 norm manually.
    def compute_l2_norm(t):
      sqr_sum_fn = lambda x: tf.reduce_sum(tf.square(x))
      return tf.sqrt(tf.add_n(tf.nest.map_structure(sqr_sum_fn, t)))

    true_norm = compute_l2_norm(true_grads)
    computed_norm = compute_l2_norm(clipped_grads)
    norm_bound = (
        l2_norm_clip * batch_size if reduction == 'sum' else l2_norm_clip
    )
    if true_norm >= norm_bound:
      # All of the per-example gradient norms should be less than the L2 norm
      # clip value. Hence, by the triangle inequality, the gradient norm of the
      # summed loss (averaged loss) should be less than the clip value times
      # the batch size (just the clip value).
      self.assertLessEqual(computed_norm, norm_bound)
    else:
      self.assertAlmostEqual(computed_norm, true_norm)


if __name__ == '__main__':
  tf.test.main()
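The norm bound asserted in `test_clipped_gradients_on_different_losses` follows from the triangle inequality: if every per-example gradient is clipped to norm at most `l2_norm_clip`, their sum has norm at most `batch_size * l2_norm_clip` and their mean has norm at most `l2_norm_clip`. A quick numerical sanity check of that argument (illustrative only, not part of the commit):

# Illustrative only (not part of the commit): numerically check the triangle
# inequality bound behind `norm_bound` in the test above.
import numpy as np

batch_size, l2_norm_clip = 10, 1.0
rng = np.random.default_rng(0)

# Synthetic per-example gradients, each rescaled to norm <= l2_norm_clip.
grads = rng.normal(size=(batch_size, 5))
norms = np.linalg.norm(grads, axis=1, keepdims=True)
grads *= np.minimum(1.0, l2_norm_clip / norms)

# 'sum' reduction: ||sum_i g_i|| <= sum_i ||g_i|| <= batch_size * l2_norm_clip.
assert np.linalg.norm(grads.sum(axis=0)) <= batch_size * l2_norm_clip + 1e-9

# Mean ('sum_over_batch_size') reduction: the same bound divided by batch_size.
assert np.linalg.norm(grads.mean(axis=0)) <= l2_norm_clip + 1e-9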