diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD index 94db8b3..fd9fc31 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD @@ -48,3 +48,25 @@ py_test( "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", ], ) + +py_library( + name = "layer_normalization", + srcs = ["layer_normalization.py"], + srcs_version = "PY3", + deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"], +) + +py_test( + name = "layer_normalization_test", + srcs = ["layer_normalization_test.py"], + python_version = "PY3", + shard_count = 8, + srcs_version = "PY3", + deps = [ + ":dense", + ":layer_normalization", + "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", + ], +) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py index aa61e3b..8bf3724 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py @@ -13,7 +13,7 @@ # limitations under the License. """Fast clipping function for `tf.keras.layers.Dense`.""" -from typing import Any, Dict, Optional, Text, Tuple +from typing import Any, Mapping, Tuple, Union import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -22,9 +22,9 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases def dense_layer_computation( layer_instance: tf.keras.layers.Dense, input_args: Tuple[Any, ...], - input_kwargs: Dict[Text, Any], + input_kwargs: Mapping[str, Any], tape: tf.GradientTape, - num_microbatches: Optional[tf.Tensor] = None, + num_microbatches: Union[tf.Tensor, None] = None, ) -> type_aliases.RegistryFunctionOutput: """Registry function for `tf.keras.layers.Dense`. diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py index 057ef39..13ab5f5 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py @@ -13,7 +13,7 @@ # limitations under the License. """Fast clipping function for `tf.keras.layers.Embedding`.""" -from typing import Any, Dict, Optional, Text, Tuple +from typing import Any, Mapping, Tuple, Union import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -21,9 +21,9 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases def embedding_layer_computation( layer_instance: tf.keras.layers.Embedding, input_args: Tuple[Any, ...], - input_kwargs: Dict[Text, Any], + input_kwargs: Mapping[str, Any], tape: tf.GradientTape, - num_microbatches: Optional[tf.Tensor] = None, + num_microbatches: Union[tf.Tensor, None] = None, ) -> type_aliases.RegistryFunctionOutput: """Registry function for `tf.keras.layers.Embedding`. diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py new file mode 100644 index 0000000..86aca3c --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py @@ -0,0 +1,89 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast clipping function for `tf.keras.layers.LayerNormalization`.""" + +from typing import Any, Mapping, Tuple, Union +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases + + +# ============================================================================== +# Supported Keras layers +# ============================================================================== +def _sqr_norm_fn(grads): + stacked_grads = tf.stack(grads, axis=-1) + reduction_axes = tf.range(1, tf.rank(stacked_grads)) + return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes) + + +def layer_normalization_computation( + layer_instance: tf.keras.layers.LayerNormalization, + input_args: Tuple[Any, ...], + input_kwargs: Mapping[str, Any], + tape: tf.GradientTape, + num_microbatches: Union[tf.Tensor, None] = None, +) -> type_aliases.RegistryFunctionOutput: + """Registry function for `tf.keras.layers.LayerNormalization`. + + This function computes actual per-example gradients and computes their + norms directly, instead of employing a chain-rule trick. This is done using + some slick reshaping calls. + + Args: + layer_instance: A `tf.keras.layers.LayerNormalization` instance. + input_args: See `dense_layer_computation()` in `dense.py`. + input_kwargs: See `dense_layer_computation()` in `dense.py`. + tape: See `dense_layer_computation()` in `dense.py`. + num_microbatches: See `dense_layer_computation()` in `dense.py`. + + Returns: + See `dense_layer_computation()` in `dense.py`. + """ + del input_kwargs # Unused in layer normaliztion calls. + if num_microbatches is not None: + raise NotImplementedError("Microbatching is not currently supported.") + + # To make sure the watched variables (beta, gamma) generate per-example + # gradients, we need to convert trainable variables from shape [S] to + # [batch_size, S] via duplication to `tf.shape(inputs)` via broadcasting. + inputs = input_args[0] + base_vars = [] + batch_size = tf.shape(inputs)[0] + + def process_variable(var): + """Expand univariate `var` and the expanded tensor to `base_vars`.""" + expanded_var = tf.tile( + tf.expand_dims(var, axis=0), [batch_size] + [1] * len(var.shape) + ) + tape.watch(expanded_var) + base_vars.append(expanded_var) + broadcast_shape = [1] * len(inputs.shape) + broadcast_shape[0] = batch_size + for d in layer_instance.axis: + broadcast_shape[d] = tf.shape(inputs)[d] + final_var = tf.reshape(expanded_var, broadcast_shape) + return final_var + + orig_gamma = layer_instance.gamma + orig_beta = layer_instance.beta + layer_instance.gamma = process_variable(orig_gamma) + layer_instance.beta = process_variable(orig_beta) + + # Do the computation, ensure that the output conforms to the unexpanded + # computation, and restore the state of the original instance. + outputs = layer_instance.call(inputs) + layer_instance.gamma = orig_gamma + layer_instance.beta = orig_beta + + return base_vars, outputs, _sqr_norm_fn diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py new file mode 100644 index 0000000..8e55214 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py @@ -0,0 +1,159 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization + + +# ============================================================================== +# Helper functions. +# ============================================================================== +def get_layer_norm_layer_generators(): + return { + 'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x), + } + + +def get_layer_norm_model_generators(): + return { + # TODO(b/274483956): Test more complex models once the we can support + # `nD` inputs for `tf.keras.layers.Dense`. + 'func1': common_test_utils.make_one_layer_functional_model, + } + + +def get_layer_norm_parameter_tuples(): + """Consists of (input_dims, parameter_axes).""" + return [ + # Rank-2 + ([3], -1), + ([3], [1]), + # Rank-3 + ([3, 4], -1), + ([3, 4], [1]), + ([3, 4], [2]), + ([3, 4], [1, 2]), + # Rank-4 + ([3, 4, 5], -1), + ([3, 4, 5], [1]), + ([3, 4, 5], [2]), + ([3, 4, 5], [3]), + ([3, 4, 5], [1, 2]), + ([3, 4, 5], [1, 3]), + ([3, 4, 5], [2, 3]), + ([3, 4, 5], [1, 2, 3]), + ] + + +def get_layer_norm_registries(): + ln_registry = layer_registry.LayerRegistry() + ln_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation) + ln_registry.insert( + tf.keras.layers.LayerNormalization, + layer_normalization.layer_normalization_computation, + ) + return { + 'layer_norm_only': ln_registry, + } + + +# ============================================================================== +# Main tests. +# ============================================================================== +class GradNormTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.strategy = tf.distribute.get_strategy() + + @parameterized.product( + model_name=list(get_layer_norm_model_generators().keys()), + layer_name=list(get_layer_norm_layer_generators().keys()), + parameter_tuple=get_layer_norm_parameter_tuples(), + layer_registry_name=list(get_layer_norm_registries().keys()), + is_eager=[True, False], + ) + def test_gradient_norms_on_various_models( + self, + model_name, + layer_name, + parameter_tuple, + layer_registry_name, + is_eager, + ): + # Parse inputs to generate test data. + input_dims, parameter_axes = parameter_tuple + + def curried_generator(a, b): + del a, b # Unused by the generator. + layer_norm_generator = get_layer_norm_layer_generators()[layer_name] + return layer_norm_generator(parameter_axes) + + # Load shared assets to all devices. + with self.strategy.scope(): + dummy_output_dim = 1 + model = common_test_utils.get_model_from_generator( + model_generator=get_layer_norm_model_generators()[model_name], + layer_generator=curried_generator, + input_dims=input_dims, + output_dims=[dummy_output_dim], + is_eager=is_eager, + ) + + # Define the main testing ops. These may be later compiled to a Graph op. + def test_op(x_batch): + return common_test_utils.get_computed_and_true_norms_from_model( + model=model, + per_example_loss_fn=None, + num_microbatches=None, + x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch, + weight_batch=None, + registry=get_layer_norm_registries()[layer_registry_name], + ) + + # TPUs can only run `tf.function`-decorated functions. + using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy) + if using_tpu: + test_op = tf.function(test_op, jit_compile=True, autograph=False) + + # TPUs use lower precision than CPUs, so we relax our criterion (see + # `dense_test.py` for additional discussions). + rtol = 1e-2 if using_tpu else 1e-3 + atol = 1e-1 if using_tpu else 1e-2 + + # Each batched input is a reshape of a `tf.range()` call. + batch_size = 2 + example_size = np.prod(input_dims) + example_values = tf.range(batch_size * example_size, dtype=tf.float32) + x_batch = tf.reshape(example_values, [batch_size] + input_dims) + batch_size = x_batch.shape[0] + # Set up the device ops and run the test. + computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,)) + # TPUs return replica contexts, which must be unwrapped. + if using_tpu: + common_test_utils.assert_replica_values_are_close(self, computed_norms) + common_test_utils.assert_replica_values_are_close(self, true_norms) + computed_norms = computed_norms.values[0] + true_norms = true_norms.values[0] + self.assertEqual(tf.shape(computed_norms)[0], batch_size) + self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py new file mode 100644 index 0000000..6891d60 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py @@ -0,0 +1,29 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test + + +class GradNormTpuTest(layer_normalization_test.GradNormTest): + + def setUp(self): + super().setUp() + self.strategy = ctu.create_tpu_strategy() + self.assertIn('TPU', self.strategy.extended.worker_devices[0]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py index 181da5a..4b460b3 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py @@ -13,12 +13,12 @@ # limitations under the License. """A collection of type aliases used throughout the clipping library.""" -from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import tensorflow as tf # Tensorflow aliases. -PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] +PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]] InputTensors = PackedTensors @@ -34,7 +34,13 @@ SquareNormFunction = Callable[[OutputTensors], tf.Tensor] RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction] RegistryFunction = Callable[ - [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape], + [ + Any, + Tuple[Any, ...], + Mapping[str, Any], + tf.GradientTape, + Union[tf.Tensor, None], + ], RegistryFunctionOutput, ]