Fix a gradient clipping bug for layer normalization layers with microbatch axes.

The previous code passed the unstacked gradients (a list) instead of the stacked gradients (a tensor) to the microbatcher, which led to unexpected behavior. This change passes the right argument and changes the original unit test to catch this bug.

PiperOrigin-RevId: 669413064
This commit is contained in:
William Kong 2024-08-30 12:40:33 -07:00 committed by A. Unique TensorFlower
parent b3963971e3
commit 66d05a22a3
2 changed files with 8 additions and 3 deletions

View file

@@ -80,8 +80,11 @@ def layer_normalization_computation(
   stacked_grads = tf.stack(grads, axis=-1)
   if num_microbatches is not None:
     stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
-        grads, num_microbatches
+        stacked_grads, num_microbatches
     )
+    # We will need to sum over the new microbatch size axis (axis=1) in order
+    # to account for microbatch aggregation.
+    stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
   reduction_axes = tf.range(1, tf.rank(stacked_grads))
   return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)

View file

@@ -134,7 +134,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     atol = 1e-1 if self.using_tpu else 1e-2
     # Each batched input is a reshape of a `tf.range()` call.
-    batch_size = 2
+    batch_size = 6
     example_size = np.prod(input_dims)
     example_values = tf.range(batch_size * example_size, dtype=tf.float32)
     x_batch = tf.reshape(example_values, [batch_size] + input_dims)
@@ -147,7 +147,9 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     common_test_utils.assert_replica_values_are_close(self, true_norms)
     computed_norms = computed_norms.values[0]
     true_norms = true_norms.values[0]
-    self.assertEqual(tf.shape(computed_norms)[0], batch_size)
+    self.assertEqual(
+        tf.shape(computed_norms)[0], num_microbatches or batch_size
+    )
     self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)