Fix a gradient clipping bug for layer normalization layers with microbatch axes.
The previous code passed the unstacked gradients (a list) instead of the stacked gradients (a tensor) to the microbatcher, which led to unexpected behavior. This change passes the right argument and changes the original unit test to catch this bug. PiperOrigin-RevId: 669413064
This commit is contained in:
parent
b3963971e3
commit
66d05a22a3
2 changed files with 8 additions and 3 deletions
|
@ -80,8 +80,11 @@ def layer_normalization_computation(
|
|||
stacked_grads = tf.stack(grads, axis=-1)
|
||||
if num_microbatches is not None:
|
||||
stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
|
||||
grads, num_microbatches
|
||||
stacked_grads, num_microbatches
|
||||
)
|
||||
# We will need to sum over the new microbatch size axis (axis=1) in order
|
||||
# to account for microbatch aggregation.
|
||||
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
|
||||
reduction_axes = tf.range(1, tf.rank(stacked_grads))
|
||||
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
|
||||
|
||||
|
|
|
@ -134,7 +134,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
|||
atol = 1e-1 if self.using_tpu else 1e-2
|
||||
|
||||
# Each batched input is a reshape of a `tf.range()` call.
|
||||
batch_size = 2
|
||||
batch_size = 6
|
||||
example_size = np.prod(input_dims)
|
||||
example_values = tf.range(batch_size * example_size, dtype=tf.float32)
|
||||
x_batch = tf.reshape(example_values, [batch_size] + input_dims)
|
||||
|
@ -147,7 +147,9 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
|||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||
computed_norms = computed_norms.values[0]
|
||||
true_norms = true_norms.values[0]
|
||||
self.assertEqual(tf.shape(computed_norms)[0], batch_size)
|
||||
self.assertEqual(
|
||||
tf.shape(computed_norms)[0], num_microbatches or batch_size
|
||||
)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue