Implement and test a registry function for tf.keras.layers.LayerNormalization.

PiperOrigin-RevId: 561423397
This commit is contained in:
A. Unique TensorFlower 2023-08-30 12:53:38 -07:00
parent 372c934d14
commit c92610e37a
7 changed files with 314 additions and 9 deletions

View file

@ -48,3 +48,25 @@ py_test(
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
], ],
) )
py_library(
name = "layer_normalization",
srcs = ["layer_normalization.py"],
srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
)
py_test(
name = "layer_normalization_test",
srcs = ["layer_normalization_test.py"],
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
":dense",
":layer_normalization",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Fast clipping function for `tf.keras.layers.Dense`.""" """Fast clipping function for `tf.keras.layers.Dense`."""
from typing import Any, Dict, Optional, Text, Tuple from typing import Any, Mapping, Tuple, Union
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@ -22,9 +22,9 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def dense_layer_computation( def dense_layer_computation(
layer_instance: tf.keras.layers.Dense, layer_instance: tf.keras.layers.Dense,
input_args: Tuple[Any, ...], input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any], input_kwargs: Mapping[str, Any],
tape: tf.GradientTape, tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None, num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput: ) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`. """Registry function for `tf.keras.layers.Dense`.

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Fast clipping function for `tf.keras.layers.Embedding`.""" """Fast clipping function for `tf.keras.layers.Embedding`."""
from typing import Any, Dict, Optional, Text, Tuple from typing import Any, Mapping, Tuple, Union
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@ -21,9 +21,9 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def embedding_layer_computation( def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding, layer_instance: tf.keras.layers.Embedding,
input_args: Tuple[Any, ...], input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any], input_kwargs: Mapping[str, Any],
tape: tf.GradientTape, tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None, num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput: ) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`. """Registry function for `tf.keras.layers.Embedding`.

View file

@ -0,0 +1,89 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""
from typing import Any, Mapping, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def _sqr_norm_fn(grads):
stacked_grads = tf.stack(grads, axis=-1)
reduction_axes = tf.range(1, tf.rank(stacked_grads))
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
def layer_normalization_computation(
layer_instance: tf.keras.layers.LayerNormalization,
input_args: Tuple[Any, ...],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.LayerNormalization`.
This function computes actual per-example gradients and computes their
norms directly, instead of employing a chain-rule trick. This is done using
some slick reshaping calls.
Args:
layer_instance: A `tf.keras.layers.LayerNormalization` instance.
input_args: See `dense_layer_computation()` in `dense.py`.
input_kwargs: See `dense_layer_computation()` in `dense.py`.
tape: See `dense_layer_computation()` in `dense.py`.
num_microbatches: See `dense_layer_computation()` in `dense.py`.
Returns:
See `dense_layer_computation()` in `dense.py`.
"""
del input_kwargs # Unused in layer normaliztion calls.
if num_microbatches is not None:
raise NotImplementedError("Microbatching is not currently supported.")
# To make sure the watched variables (beta, gamma) generate per-example
# gradients, we need to convert trainable variables from shape [S] to
# [batch_size, S] via duplication to `tf.shape(inputs)` via broadcasting.
inputs = input_args[0]
base_vars = []
batch_size = tf.shape(inputs)[0]
def process_variable(var):
"""Expand univariate `var` and the expanded tensor to `base_vars`."""
expanded_var = tf.tile(
tf.expand_dims(var, axis=0), [batch_size] + [1] * len(var.shape)
)
tape.watch(expanded_var)
base_vars.append(expanded_var)
broadcast_shape = [1] * len(inputs.shape)
broadcast_shape[0] = batch_size
for d in layer_instance.axis:
broadcast_shape[d] = tf.shape(inputs)[d]
final_var = tf.reshape(expanded_var, broadcast_shape)
return final_var
orig_gamma = layer_instance.gamma
orig_beta = layer_instance.beta
layer_instance.gamma = process_variable(orig_gamma)
layer_instance.beta = process_variable(orig_beta)
# Do the computation, ensure that the output conforms to the unexpanded
# computation, and restore the state of the original instance.
outputs = layer_instance.call(inputs)
layer_instance.gamma = orig_gamma
layer_instance.beta = orig_beta
return base_vars, outputs, _sqr_norm_fn

View file

@ -0,0 +1,159 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization
# ==============================================================================
# Helper functions.
# ==============================================================================
def get_layer_norm_layer_generators():
return {
'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
}
def get_layer_norm_model_generators():
return {
# TODO(b/274483956): Test more complex models once the we can support
# `nD` inputs for `tf.keras.layers.Dense`.
'func1': common_test_utils.make_one_layer_functional_model,
}
def get_layer_norm_parameter_tuples():
"""Consists of (input_dims, parameter_axes)."""
return [
# Rank-2
([3], -1),
([3], [1]),
# Rank-3
([3, 4], -1),
([3, 4], [1]),
([3, 4], [2]),
([3, 4], [1, 2]),
# Rank-4
([3, 4, 5], -1),
([3, 4, 5], [1]),
([3, 4, 5], [2]),
([3, 4, 5], [3]),
([3, 4, 5], [1, 2]),
([3, 4, 5], [1, 3]),
([3, 4, 5], [2, 3]),
([3, 4, 5], [1, 2, 3]),
]
def get_layer_norm_registries():
ln_registry = layer_registry.LayerRegistry()
ln_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
ln_registry.insert(
tf.keras.layers.LayerNormalization,
layer_normalization.layer_normalization_computation,
)
return {
'layer_norm_only': ln_registry,
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
@parameterized.product(
model_name=list(get_layer_norm_model_generators().keys()),
layer_name=list(get_layer_norm_layer_generators().keys()),
parameter_tuple=get_layer_norm_parameter_tuples(),
layer_registry_name=list(get_layer_norm_registries().keys()),
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self,
model_name,
layer_name,
parameter_tuple,
layer_registry_name,
is_eager,
):
# Parse inputs to generate test data.
input_dims, parameter_axes = parameter_tuple
def curried_generator(a, b):
del a, b # Unused by the generator.
layer_norm_generator = get_layer_norm_layer_generators()[layer_name]
return layer_norm_generator(parameter_axes)
# Load shared assets to all devices.
with self.strategy.scope():
dummy_output_dim = 1
model = common_test_utils.get_model_from_generator(
model_generator=get_layer_norm_model_generators()[model_name],
layer_generator=curried_generator,
input_dims=input_dims,
output_dims=[dummy_output_dim],
is_eager=is_eager,
)
# Define the main testing ops. These may be later compiled to a Graph op.
def test_op(x_batch):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=None,
num_microbatches=None,
x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch,
weight_batch=None,
registry=get_layer_norm_registries()[layer_registry_name],
)
# TPUs can only run `tf.function`-decorated functions.
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if using_tpu:
test_op = tf.function(test_op, jit_compile=True, autograph=False)
# TPUs use lower precision than CPUs, so we relax our criterion (see
# `dense_test.py` for additional discussions).
rtol = 1e-2 if using_tpu else 1e-3
atol = 1e-1 if using_tpu else 1e-2
# Each batched input is a reshape of a `tf.range()` call.
batch_size = 2
example_size = np.prod(input_dims)
example_values = tf.range(batch_size * example_size, dtype=tf.float32)
x_batch = tf.reshape(example_values, [batch_size] + input_dims)
batch_size = x_batch.shape[0]
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
self.assertEqual(tf.shape(computed_norms)[0], batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test
class GradNormTpuTest(layer_normalization_test.GradNormTest):
def setUp(self):
super().setUp()
self.strategy = ctu.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
if __name__ == '__main__':
tf.test.main()

View file

@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
"""A collection of type aliases used throughout the clipping library.""" """A collection of type aliases used throughout the clipping library."""
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import tensorflow as tf import tensorflow as tf
# Tensorflow aliases. # Tensorflow aliases.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]
InputTensors = PackedTensors InputTensors = PackedTensors
@ -34,7 +34,13 @@ SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction] RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
RegistryFunction = Callable[ RegistryFunction = Callable[
[Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape], [
Any,
Tuple[Any, ...],
Mapping[str, Any],
tf.GradientTape,
Union[tf.Tensor, None],
],
RegistryFunctionOutput, RegistryFunctionOutput,
] ]