Implement and test a registry function for tf.keras.layers.LayerNormalization.

PiperOrigin-RevId: 561423397

parent 372c934d14
commit c92610e37a

7 changed files with 314 additions and 9 deletions
@@ -48,3 +48,25 @@ py_test(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
     ],
 )
+
+py_library(
+    name = "layer_normalization",
+    srcs = ["layer_normalization.py"],
+    srcs_version = "PY3",
+    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+)
+
+py_test(
+    name = "layer_normalization_test",
+    srcs = ["layer_normalization_test.py"],
+    python_version = "PY3",
+    shard_count = 8,
+    srcs_version = "PY3",
+    deps = [
+        ":dense",
+        ":layer_normalization",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+    ],
+)
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Dense`."""
 
-from typing import Any, Dict, Optional, Text, Tuple
+from typing import Any, Mapping, Tuple, Union
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -22,9 +22,9 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
     input_args: Tuple[Any, ...],
-    input_kwargs: Dict[Text, Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Optional[tf.Tensor] = None,
+    num_microbatches: Union[tf.Tensor, None] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Dense`.
 
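The typing change above is behavior-neutral: `typing.Text` is a legacy alias of `str`, `Union[tf.Tensor, None]` is equivalent to `Optional[tf.Tensor]`, and `Mapping` (unlike `Dict`) advertises that the keyword arguments are only read, never mutated. A minimal sketch of the distinction (illustrative only, not part of the commit):

```python
from types import MappingProxyType
from typing import Any, Mapping


def read_only_kwargs(input_kwargs: Mapping[str, Any]) -> int:
  # `Mapping` exposes no `__setitem__`, so a type checker flags any attempt
  # to mutate `input_kwargs`, documenting that registry functions only read.
  return len(input_kwargs)


# Any dict-like object satisfies `Mapping[str, Any]`, including immutable ones.
print(read_only_kwargs({'training': False}))                    # 1
print(read_only_kwargs(MappingProxyType({'training': False})))  # 1
```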
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Embedding`."""
 
-from typing import Any, Dict, Optional, Text, Tuple
+from typing import Any, Mapping, Tuple, Union
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 
@@ -21,9 +21,9 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
     input_args: Tuple[Any, ...],
-    input_kwargs: Dict[Text, Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Optional[tf.Tensor] = None,
+    num_microbatches: Union[tf.Tensor, None] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Embedding`.
 
@@ -0,0 +1,89 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""
+
+from typing import Any, Mapping, Tuple, Union
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+
+
+# ==============================================================================
+# Supported Keras layers
+# ==============================================================================
+def _sqr_norm_fn(grads):
+  stacked_grads = tf.stack(grads, axis=-1)
+  reduction_axes = tf.range(1, tf.rank(stacked_grads))
+  return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
+
+
+def layer_normalization_computation(
+    layer_instance: tf.keras.layers.LayerNormalization,
+    input_args: Tuple[Any, ...],
+    input_kwargs: Mapping[str, Any],
+    tape: tf.GradientTape,
+    num_microbatches: Union[tf.Tensor, None] = None,
+) -> type_aliases.RegistryFunctionOutput:
+  """Registry function for `tf.keras.layers.LayerNormalization`.
+
+  This function computes actual per-example gradients and their norms
+  directly, instead of employing a chain-rule trick. This is done using
+  some slick reshaping calls.
+
+  Args:
+    layer_instance: A `tf.keras.layers.LayerNormalization` instance.
+    input_args: See `dense_layer_computation()` in `dense.py`.
+    input_kwargs: See `dense_layer_computation()` in `dense.py`.
+    tape: See `dense_layer_computation()` in `dense.py`.
+    num_microbatches: See `dense_layer_computation()` in `dense.py`.
+
+  Returns:
+    See `dense_layer_computation()` in `dense.py`.
+  """
+  del input_kwargs  # Unused in layer normalization calls.
+  if num_microbatches is not None:
+    raise NotImplementedError("Microbatching is not currently supported.")
+
+  # To make sure the watched variables (beta, gamma) generate per-example
+  # gradients, we need to convert the trainable variables from shape [S] to
+  # [batch_size, S] via duplication, reshaped to broadcast against the inputs.
+  inputs = input_args[0]
+  base_vars = []
+  batch_size = tf.shape(inputs)[0]
+
+  def process_variable(var):
+    """Expands univariate `var`, adding the expanded tensor to `base_vars`."""
+    expanded_var = tf.tile(
+        tf.expand_dims(var, axis=0), [batch_size] + [1] * len(var.shape)
+    )
+    tape.watch(expanded_var)
+    base_vars.append(expanded_var)
+    broadcast_shape = [1] * len(inputs.shape)
+    broadcast_shape[0] = batch_size
+    for d in layer_instance.axis:
+      broadcast_shape[d] = tf.shape(inputs)[d]
+    final_var = tf.reshape(expanded_var, broadcast_shape)
+    return final_var
+
+  orig_gamma = layer_instance.gamma
+  orig_beta = layer_instance.beta
+  layer_instance.gamma = process_variable(orig_gamma)
+  layer_instance.beta = process_variable(orig_beta)
+
+  # Do the computation, ensure that the output conforms to the unexpanded
+  # computation, and restore the state of the original instance.
+  outputs = layer_instance.call(inputs)
+  layer_instance.gamma = orig_gamma
+  layer_instance.beta = orig_beta
+
+  return base_vars, outputs, _sqr_norm_fn
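As a sanity check on the trick above, here is a standalone sketch (hypothetical, not part of the commit) that applies the same tile-watch-restore steps to a rank-2 input with `axis=-1` and then reduces the stacked gradients exactly as `_sqr_norm_fn` does. Because the i-th expanded row only influences the i-th output, the gradient of the *summed* loss with respect to the expanded copies recovers all per-example gradients in a single tape pass:

```python
import tensorflow as tf

batch_size, feature_dim = 4, 3
x = tf.reshape(
    tf.range(batch_size * feature_dim, dtype=tf.float32),
    [batch_size, feature_dim],
)
layer = tf.keras.layers.LayerNormalization(axis=-1)
layer.build(x.shape)  # Creates `gamma` and `beta`, each of shape [3].

with tf.GradientTape() as tape:
  orig_gamma, orig_beta = layer.gamma, layer.beta
  expanded = []
  for var in (orig_gamma, orig_beta):
    # Give every example its own copy of the shared weight, and watch the
    # copies so the tape tracks per-example contributions.
    copy = tf.tile(tf.expand_dims(var, axis=0), [batch_size, 1])
    tape.watch(copy)
    expanded.append(copy)
  # For rank-2 inputs normalized over axis=-1, the broadcast shape
  # [batch_size, feature_dim] already matches the expanded shape.
  layer.gamma, layer.beta = expanded
  outputs = layer.call(x)
  layer.gamma, layer.beta = orig_gamma, orig_beta  # Restore the layer.
  loss = tf.reduce_sum(tf.square(outputs))  # Sum of per-example losses.

grads = tape.gradient(loss, expanded)  # Row i holds example i's gradient.
stacked = tf.stack(grads, axis=-1)  # Mirrors `_sqr_norm_fn`.
sqr_norms = tf.reduce_sum(
    tf.square(stacked), axis=tf.range(1, tf.rank(stacked))
)
print(sqr_norms)  # Squared per-example gradient norms, shape [4].
```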
@@ -0,0 +1,159 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization
+
+
+# ==============================================================================
+# Helper functions.
+# ==============================================================================
+def get_layer_norm_layer_generators():
+  return {
+      'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
+  }
+
+
+def get_layer_norm_model_generators():
+  return {
+      # TODO(b/274483956): Test more complex models once we can support
+      # `nD` inputs for `tf.keras.layers.Dense`.
+      'func1': common_test_utils.make_one_layer_functional_model,
+  }
+
+
+def get_layer_norm_parameter_tuples():
+  """Consists of (input_dims, parameter_axes)."""
+  return [
+      # Rank-2
+      ([3], -1),
+      ([3], [1]),
+      # Rank-3
+      ([3, 4], -1),
+      ([3, 4], [1]),
+      ([3, 4], [2]),
+      ([3, 4], [1, 2]),
+      # Rank-4
+      ([3, 4, 5], -1),
+      ([3, 4, 5], [1]),
+      ([3, 4, 5], [2]),
+      ([3, 4, 5], [3]),
+      ([3, 4, 5], [1, 2]),
+      ([3, 4, 5], [1, 3]),
+      ([3, 4, 5], [2, 3]),
+      ([3, 4, 5], [1, 2, 3]),
+  ]
+
+
+def get_layer_norm_registries():
+  ln_registry = layer_registry.LayerRegistry()
+  ln_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
+  ln_registry.insert(
+      tf.keras.layers.LayerNormalization,
+      layer_normalization.layer_normalization_computation,
+  )
+  return {
+      'layer_norm_only': ln_registry,
+  }
+
+
+# ==============================================================================
+# Main tests.
+# ==============================================================================
+class GradNormTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = tf.distribute.get_strategy()
+
+  @parameterized.product(
+      model_name=list(get_layer_norm_model_generators().keys()),
+      layer_name=list(get_layer_norm_layer_generators().keys()),
+      parameter_tuple=get_layer_norm_parameter_tuples(),
+      layer_registry_name=list(get_layer_norm_registries().keys()),
+      is_eager=[True, False],
+  )
+  def test_gradient_norms_on_various_models(
+      self,
+      model_name,
+      layer_name,
+      parameter_tuple,
+      layer_registry_name,
+      is_eager,
+  ):
+    # Parse inputs to generate test data.
+    input_dims, parameter_axes = parameter_tuple
+
+    def curried_generator(a, b):
+      del a, b  # Unused by the generator.
+      layer_norm_generator = get_layer_norm_layer_generators()[layer_name]
+      return layer_norm_generator(parameter_axes)
+
+    # Load shared assets to all devices.
+    with self.strategy.scope():
+      dummy_output_dim = 1
+      model = common_test_utils.get_model_from_generator(
+          model_generator=get_layer_norm_model_generators()[model_name],
+          layer_generator=curried_generator,
+          input_dims=input_dims,
+          output_dims=[dummy_output_dim],
+          is_eager=is_eager,
+      )
+
+    # Define the main testing ops. These may be later compiled to a Graph op.
+    def test_op(x_batch):
+      return common_test_utils.get_computed_and_true_norms_from_model(
+          model=model,
+          per_example_loss_fn=None,
+          num_microbatches=None,
+          x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch,
+          weight_batch=None,
+          registry=get_layer_norm_registries()[layer_registry_name],
+      )
+
+    # TPUs can only run `tf.function`-decorated functions.
+    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
+    if using_tpu:
+      test_op = tf.function(test_op, jit_compile=True, autograph=False)
+
+    # TPUs use lower precision than CPUs, so we relax our criterion (see
+    # `dense_test.py` for additional discussions).
+    rtol = 1e-2 if using_tpu else 1e-3
+    atol = 1e-1 if using_tpu else 1e-2
+
+    # Each batched input is a reshape of a `tf.range()` call.
+    batch_size = 2
+    example_size = np.prod(input_dims)
+    example_values = tf.range(batch_size * example_size, dtype=tf.float32)
+    x_batch = tf.reshape(example_values, [batch_size] + input_dims)
+    batch_size = x_batch.shape[0]
+
+    # Set up the device ops and run the test.
+    computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
+    # TPUs return replica contexts, which must be unwrapped.
+    if using_tpu:
+      common_test_utils.assert_replica_values_are_close(self, computed_norms)
+      common_test_utils.assert_replica_values_are_close(self, true_norms)
+      computed_norms = computed_norms.values[0]
+      true_norms = true_norms.values[0]
+    self.assertEqual(tf.shape(computed_norms)[0], batch_size)
+    self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
+
+
+if __name__ == '__main__':
+  tf.test.main()
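For reference, the `true_norms` that `get_computed_and_true_norms_from_model` compares against are, conceptually, brute-force per-example gradient norms. A hypothetical standalone baseline (the helper name and toy model below are illustrative, not the library's internals):

```python
import tensorflow as tf


def naive_per_example_norms(model, loss_fn, x_batch, y_batch):
  """One full tape pass per example; slow, but unambiguous ground truth."""
  norms = []
  for i in range(x_batch.shape[0]):
    with tf.GradientTape() as tape:
      y_pred = model(x_batch[i : i + 1], training=True)
      loss = loss_fn(y_batch[i : i + 1], y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    sq = tf.add_n(
        [tf.reduce_sum(tf.square(g)) for g in grads if g is not None]
    )
    norms.append(tf.sqrt(sq))
  return tf.stack(norms)


model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(3,)),
    tf.keras.layers.LayerNormalization(axis=-1),
    tf.keras.layers.Dense(1),
])
x = tf.random.normal([2, 3])
y = tf.zeros([2, 1])
print(naive_per_example_norms(model, tf.keras.losses.MeanSquaredError(), x, y))
```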
@@ -0,0 +1,29 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test
+
+
+class GradNormTpuTest(layer_normalization_test.GradNormTest):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = ctu.create_tpu_strategy()
+    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+
+
+if __name__ == '__main__':
+  tf.test.main()
@@ -13,12 +13,12 @@
 # limitations under the License.
 """A collection of type aliases used throughout the clipping library."""
 
-from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
 import tensorflow as tf
 
 
 # Tensorflow aliases.
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
+PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]
 
 InputTensors = PackedTensors
 
@@ -34,7 +34,13 @@ SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
 RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
 
 RegistryFunction = Callable[
-    [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
+    [
+        Any,
+        Tuple[Any, ...],
+        Mapping[str, Any],
+        tf.GradientTape,
+        Union[tf.Tensor, None],
+    ],
     RegistryFunctionOutput,
 ]
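With the widened alias, a conforming registry function takes five positional arguments, ending with the optional `num_microbatches` tensor. A hypothetical skeleton showing the expected shape (`identity_layer_computation` is illustrative only, not part of the library):

```python
from typing import Any, Mapping, Tuple, Union

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def identity_layer_computation(
    layer_instance: Any,
    input_args: Tuple[Any, ...],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Skeleton of a callable compatible with `RegistryFunction`."""
  del input_kwargs, num_microbatches  # A trivial layer can ignore these.
  base_vars = input_args[0]
  tape.watch(base_vars)
  outputs = layer_instance.call(base_vars)

  def sqr_norm_fn(grads):
    # Reduce over every axis except the leading (batch) dimension.
    return tf.reduce_sum(tf.square(grads), axis=tf.range(1, tf.rank(grads)))

  return base_vars, outputs, sqr_norm_fn
```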