From 66d05a22a397402c34a6e3b8fc0f2ad7f8a8d420 Mon Sep 17 00:00:00 2001
From: William Kong
Date: Fri, 30 Aug 2024 12:40:33 -0700
Subject: [PATCH] Fix a gradient clipping bug for layer normalization layers
 with microbatch axes.

The previous code passed the unstacked gradients (a list) instead of the
stacked gradients (a tensor) to the microbatcher, which led to unexpected
behavior. This change passes the right argument and changes the original
unit test to catch this bug.

PiperOrigin-RevId: 669413064
---
 .../registry_functions/layer_normalization.py      | 5 ++++-
 .../registry_functions/layer_normalization_test.py | 6 ++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py
index 849ace6..e79a52b 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py
@@ -80,8 +80,11 @@ def layer_normalization_computation(
     stacked_grads = tf.stack(grads, axis=-1)
     if num_microbatches is not None:
       stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
-          grads, num_microbatches
+          stacked_grads, num_microbatches
       )
+      # We will need to sum over the new microbatch size axis (axis=1) in order
+      # to account for microbatch aggregation.
+      stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
     reduction_axes = tf.range(1, tf.rank(stacked_grads))
     return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
 
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py
index c0b18d8..f720435 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py
@@ -134,7 +134,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     atol = 1e-1 if self.using_tpu else 1e-2
 
     # Each batched input is a reshape of a `tf.range()` call.
-    batch_size = 2
+    batch_size = 6
     example_size = np.prod(input_dims)
     example_values = tf.range(batch_size * example_size, dtype=tf.float32)
     x_batch = tf.reshape(example_values, [batch_size] + input_dims)
@@ -147,7 +147,9 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]
       true_norms = true_norms.values[0]
 
-    self.assertEqual(tf.shape(computed_norms)[0], batch_size)
+    self.assertEqual(
+        tf.shape(computed_norms)[0], num_microbatches or batch_size
+    )
     self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
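
Note: the following is a minimal, self-contained sketch of the behavior this
patch restores, not library code. The `add_microbatch_axis()` helper below is
an illustrative stand-in for the tensor branch of
`common_manip_utils.maybe_add_microbatch_axis()`, assuming it reshapes a
`[batch_size, ...]` tensor into `[num_microbatches, batch_size /
num_microbatches, ...]`; all shapes are hypothetical.

    import tensorflow as tf

    def add_microbatch_axis(x, num_microbatches):
      # Stand-in assumption for `maybe_add_microbatch_axis()`: reshape
      # [batch_size, ...] -> [num_microbatches, -1, ...].
      new_shape = tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
      return tf.reshape(x, new_shape)

    batch_size, num_microbatches = 6, 3
    # Per-example gradients for two variables (e.g. a LayerNorm layer's gamma
    # and beta), stacked along a trailing axis as in the registry function.
    grads = [tf.ones([batch_size, 4]), tf.ones([batch_size, 4])]
    stacked_grads = tf.stack(grads, axis=-1)  # shape [6, 4, 2]

    # The bug: passing `grads` (a Python list) here instead of the stacked
    # tensor; the reshape-based microbatcher expects a single tensor.
    stacked_grads = add_microbatch_axis(stacked_grads, num_microbatches)
    # shape [3, 2, 4, 2]: [num_microbatches, microbatch_size, ...]

    # Per-example gradients within a microbatch add up, so sum over the
    # microbatch size axis (axis=1) before computing squared norms.
    stacked_grads = tf.reduce_sum(stacked_grads, axis=1)  # shape [3, 4, 2]
    reduction_axes = tf.range(1, tf.rank(stacked_grads))
    sqr_norms = tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
    print(sqr_norms.shape)  # (3,): one squared norm per microbatch

The final shape is also why the updated unit test asserts the leading
dimension of the computed norms against `num_microbatches or batch_size`
rather than `batch_size` alone.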