diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
index 358d8c3..e2da95d 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
@@ -102,7 +102,10 @@ py_library(
     name = "layer_normalization",
     srcs = ["layer_normalization.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+    deps = [
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
 )
 
 py_test(
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py
index 86aca3c..91338a9 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py
@@ -15,18 +15,13 @@
 from typing import Any, Mapping, Tuple, Union
 
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 
 
 # ==============================================================================
 # Supported Keras layers
 # ==============================================================================
-def _sqr_norm_fn(grads):
-  stacked_grads = tf.stack(grads, axis=-1)
-  reduction_axes = tf.range(1, tf.rank(stacked_grads))
-  return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
-
-
 def layer_normalization_computation(
     layer_instance: tf.keras.layers.LayerNormalization,
     input_args: Tuple[Any, ...],
@@ -51,9 +46,6 @@ def layer_normalization_computation(
     See `dense_layer_computation()` in `dense.py`.
   """
   del input_kwargs  # Unused in layer normalization calls.
-  if num_microbatches is not None:
-    raise NotImplementedError("Microbatching is not currently supported.")
-
   # To make sure the watched variables (beta, gamma) generate per-example
   # gradients, we need to convert trainable variables from shape [S] to
   # [batch_size, S] via duplication to `tf.shape(inputs)` via broadcasting.
@@ -86,4 +78,15 @@ def layer_normalization_computation(
   layer_instance.gamma = orig_gamma
   layer_instance.beta = orig_beta
 
-  return base_vars, outputs, _sqr_norm_fn
+  def sqr_norm_fn(grads):
+    stacked_grads = tf.stack(grads, axis=-1)
+    if num_microbatches is not None:
+      stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
+          stacked_grads, num_microbatches
+      )
+      # Aggregate the per-example gradients within each microbatch.
+      stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
+    reduction_axes = tf.range(1, tf.rank(stacked_grads))
+    return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
+
+  return base_vars, outputs, sqr_norm_fn
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py
index 8e55214..c2b6429 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py
@@ -87,6 +87,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       layer_name=list(get_layer_norm_layer_generators().keys()),
       parameter_tuple=get_layer_norm_parameter_tuples(),
       layer_registry_name=list(get_layer_norm_registries().keys()),
+      num_microbatches=[None, 2],
       is_eager=[True, False],
   )
   def test_gradient_norms_on_various_models(
@@ -95,6 +96,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       layer_name,
       parameter_tuple,
       layer_registry_name,
+      num_microbatches,
       is_eager,
   ):
     # Parse inputs to generate test data.
@@ -121,7 +123,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     return common_test_utils.get_computed_and_true_norms_from_model(
         model=model,
         per_example_loss_fn=None,
-        num_microbatches=None,
+        num_microbatches=num_microbatches,
         x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch,
         weight_batch=None,
         registry=get_layer_norm_registries()[layer_registry_name],
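
For reference, below is a minimal standalone sketch (not part of the diff) of what the new `sqr_norm_fn` computes when microbatching is enabled. It assumes `maybe_add_microbatch_axis` reshapes the leading batch axis of a `[batch, ...]` tensor to `[num_microbatches, batch // num_microbatches, ...]`; that reshape is inlined here as a plain `tf.reshape`, and the tensor sizes are made up for illustration.

    import tensorflow as tf

    batch_size, param_dim, num_microbatches = 4, 3, 2

    # Per-example gradients for the two watched variables (beta, gamma),
    # each of shape [batch_size, param_dim].
    grads = [tf.random.normal([batch_size, param_dim]) for _ in range(2)]

    stacked_grads = tf.stack(grads, axis=-1)  # [batch, param_dim, 2]

    # Inlined stand-in for `maybe_add_microbatch_axis`: split the batch axis
    # into [num_microbatches, microbatch_size, ...].
    stacked_grads = tf.reshape(
        stacked_grads,
        tf.concat(
            [[num_microbatches, -1], tf.shape(stacked_grads)[1:]], axis=0
        ),
    )
    # Aggregate the per-example gradients within each microbatch.
    stacked_grads = tf.reduce_sum(stacked_grads, axis=1)

    # Squared L2 norm per microbatch: sum of squares over all other axes.
    reduction_axes = tf.range(1, tf.rank(stacked_grads))
    sqr_norms = tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
    print(sqr_norms.shape)  # (2,) -- one squared norm per microbatch

When `num_microbatches` is None the reshape-and-sum step is skipped and the function reduces to the old `_sqr_norm_fn`, returning one squared norm per example, which is why the test above parameterizes over both `None` and `2`.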