From b19088f048b429363cfd7c7aff7256db4fc85dc1 Mon Sep 17 00:00:00 2001 From: William Kong Date: Wed, 22 Nov 2023 07:17:12 -0800 Subject: [PATCH] Implement and test a registry function for `tf.keras.layers.MultiHeadAttention`. PiperOrigin-RevId: 584620638 --- .../registry_functions/BUILD | 26 ++ .../multi_head_attention.py | 214 ++++++++++ .../multi_head_attention_test.py | 375 ++++++++++++++++++ .../multi_head_attention_tpu_test.py | 30 ++ 4 files changed, 645 insertions(+) create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention.py create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_test.py create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_tpu_test.py diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD index 2f7ec0e..b72263e 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD @@ -176,3 +176,29 @@ py_test( "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", ], ) + +py_library( + name = "multi_head_attention", + srcs = ["multi_head_attention.py"], + srcs_version = "PY3", + deps = [ + ":einsum_dense", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases", + ], +) + +py_test( + name = "multi_head_attention_test", + srcs = ["multi_head_attention_test.py"], + python_version = "PY3", + shard_count = 8, + srcs_version = "PY3", + deps = [ + ":dense", + ":multi_head_attention", + "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", + ], +) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention.py new file mode 100644 index 0000000..4078c06 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention.py @@ -0,0 +1,214 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast clipping function for `tf.keras.layers.MultiHeadAttention`.""" + +from collections.abc import Mapping, Sequence +from typing import Any, Optional +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense + + +def multi_head_attention_layer_computation( + layer_instance: tf.keras.layers.MultiHeadAttention, + input_args: Sequence[Any], + input_kwargs: Mapping[str, Any], + tape: tf.GradientTape, + num_microbatches: Optional[tf.Tensor] = None, +) -> type_aliases.RegistryFunctionOutput: + """Registry function for `tf.keras.layers.MultiHeadAttention`. + + This function essentially applies the registry function for + `tf.keras.layers.EinsumDense` three times. Some hints about the nature of + the Einsum transforms are given below. + + ------------------- + ABOUT INPUT SHAPES + ------------------- + For a given {query, key, value} input `I` of shape + + [Eq. A] tf.shape(I) == [n, a[0],... , a[k-1], b] + + where `n` is the batch size, the corresponding Einsum equation for its + `EinsumDense` transform is given by: + + {n a[0] ... a[k-1] b},{b c d}->{n a[1] ... a[k-1] c d} + + where `c` corresponds to the number of attention heads + (`layer_instance.num_heads`) and `d` corresponds to the size per head + (`layer_instance.key_dim` or `layer_instance.value_dim`). + + It is expected that the rank of the query, key, and value inputs are the same. + + ------------------ + ABOUT OUTPUT SHAPE + ------------------ + Suppose the shape of the `query` input `Q` is given by [Eq. A] above with + `I == Q`. Then, if `layer_instance.output_shape is None`, the output `O` of + the layer satisfies `tf.shape(Q) == tf.shape(O)`. However, if we have + `layer_instance.output_shape is not None`, then + + tf.shape(Q) == [n, a[0], ..., a[k-1], *layer_instance.output_shape] + + Args: + layer_instance: A `tf.keras.layers.MultiHeadAttention` instance. + input_args: See `dense_layer_computation()`. + input_kwargs: See `dense_layer_computation()`. + tape: See `dense_layer_computation()`. + num_microbatches: See `dense_layer_computation()`. + + Returns: + See `dense_layer_computation()`. + """ + # ---------------------- + # PREPROCESS THE INPUTS. 
+ # ---------------------- + query = ( + input_kwargs.get("query") + if input_kwargs.get("query") is not None + else input_args[0] + ) + value = ( + input_kwargs.get("value") + if input_kwargs.get("value") is not None + else input_args[1] + ) + key = input_kwargs.get("key") + attention_mask = input_kwargs.get("attention_mask") + return_attention_scores = input_kwargs.get("return_attention_scores") + training = input_kwargs.get("training") + use_causal_mask = input_kwargs.get("use_causal_mask") + attention_mask = layer_instance._compute_attention_mask( # pylint: disable=protected-access + query, + value, + key=key, + attention_mask=attention_mask, + use_causal_mask=use_causal_mask, + ) + if not layer_instance._built_from_signature: # pylint: disable=protected-access + layer_instance._build_from_signature(query=query, value=value, key=key) # pylint: disable=protected-access + if key is None: + key = value + + query_lengths = 0 + query_is_ragged = isinstance(query, tf.RaggedTensor) + if query_is_ragged: + query_lengths = query.nested_row_lengths() + query = query.to_tensor() + + key_is_ragged = isinstance(key, tf.RaggedTensor) + value_is_ragged = isinstance(value, tf.RaggedTensor) + if key_is_ragged and value_is_ragged: + bounding_shape = tf.math.maximum( + key.bounding_shape(), value.bounding_shape() + ) + key = key.to_tensor(shape=bounding_shape) + value = value.to_tensor(shape=bounding_shape) + elif key_is_ragged: + key = key.to_tensor(shape=tf.shape(value)) + elif value_is_ragged: + value = value.to_tensor(shape=tf.shape(key)) + else: + pass + # ------------------------------ + # APPLY THE FAST CLIPPING TRICK. + # ------------------------------ + # trainable_op: W_q * QUERY + query_base_vars, query, query_sqr_norm_fn = ( + einsum_dense.einsum_layer_computation( + layer_instance._query_dense, # pylint: disable=protected-access + (query,), + {}, + tape, + num_microbatches, + ) + ) + # trainable_op: W_k * KEY + key_base_vars, key, key_sqr_norm_fn = einsum_dense.einsum_layer_computation( + layer_instance._key_dense, # pylint: disable=protected-access + (key,), + {}, + tape, + num_microbatches, + ) + # trainable_op: W_v * VALUE + value_base_vars, value, value_sqr_norm_fn = ( + einsum_dense.einsum_layer_computation( + layer_instance._value_dense, # pylint: disable=protected-access + (value,), + {}, + tape, + num_microbatches, + ) + ) + # op: TEMP = ATTENTION(W_q * QUERY, W_k * KEY, W_v * VALUE) + temp_output, attention_scores = layer_instance._compute_attention( # pylint: disable=protected-access + query, + key, + value, + attention_mask, + training, + ) + # trainable_op: W_o * OUTPUT + ( + attention_output_base_vars, + attention_output, + attention_output_sqr_norm_fn, + ) = einsum_dense.einsum_layer_computation( + layer_instance._output_dense, # pylint: disable=protected-access + (temp_output,), + {}, + tape, + num_microbatches, + ) + # ------------------------ + # POSTPROCESS THE OUTPUTS. + # ------------------------ + # Get registry output tensors ready. + if query_is_ragged: + attention_output = tf.RaggedTensor.from_tensor( + attention_output, query_lengths + ) + outputs = attention_output + if return_attention_scores: + outputs = (attention_output, attention_scores) + base_vars = [ + query_base_vars, + key_base_vars, + value_base_vars, + attention_output_base_vars, + ] + + # The square norm function should just aggregate the squared norms + # corresponding to each trainable op. 
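+  # Concretely, each entry of `grad_list` holds the gradients for one of the
+  # four projections above (query, key, value, output); their per-example
+  # squared norms are stacked along axis 1 and summed so that the result is a
+  # single squared norm per example (or per microbatch when
+  # `num_microbatches` is not None).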
+ def sqr_norm_fn(grad_list): + if len(grad_list) != 4: + raise ValueError( + "Expected a container of 4 gradients for the `MultiheadAttention` " + "square norm function's input. Instead, received a container of " + "size " + + str(len(grad_list)) + ) + combined_sqr_norms = tf.stack( + [ + query_sqr_norm_fn(grad_list[0]), + key_sqr_norm_fn(grad_list[1]), + value_sqr_norm_fn(grad_list[2]), + attention_output_sqr_norm_fn(grad_list[3]), + ], + axis=1, + ) + return tf.reduce_sum(combined_sqr_norms, axis=1) + + return base_vars, outputs, sqr_norm_fn diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_test.py new file mode 100644 index 0000000..8f0f289 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_test.py @@ -0,0 +1,375 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import multi_head_attention + + +def get_attention_layer_generators(): + def basic_attention_layer( + num_heads, + key_dim, + value_dim, + dropout, + use_bias, + output_shape, + ): + return tf.keras.layers.MultiHeadAttention( + num_heads, + key_dim, + value_dim=value_dim, + dropout=dropout, + use_bias=use_bias, + output_shape=output_shape, + ) + + return { + 'basic_attention_layer': basic_attention_layer, + } + + +def make_one_layer_attention_model(layer_generator, input_dims, output_dims): + """Creates a 1-layer MultiHeadAttention model.""" + inputs, input_args, input_kwargs = get_multi_head_attention_model_inputs( + input_dims + ) + layer1 = layer_generator(input_dims, output_dims) + del output_dims + temp1 = layer1(*input_args, **input_kwargs) + outputs = common_test_utils.reshape_and_sum(temp1) + return tf.keras.Model(inputs=inputs, outputs=outputs) + + +def make_two_layer_attention_model(layer_generator, input_dims, output_dims): + """Creates a 2-layer MultiHeadAttention model.""" + inputs, input_args, input_kwargs = get_multi_head_attention_model_inputs( + input_dims + ) + layer1 = layer_generator(input_dims, output_dims) + temp1 = layer1(*input_args, **input_kwargs) + temp2 = tf.keras.layers.Dense(1)(temp1) + outputs = common_test_utils.reshape_and_sum(temp2) + return tf.keras.Model(inputs=inputs, outputs=outputs) + + +def get_attention_model_generators(): + return { + 'seq1_mha': make_one_layer_attention_model, + 'seq2_mha': make_two_layer_attention_model, + } + + +def get_attention_parameter_tuples(): + # (query_input_dims, value_input_dims, use_key, use_attention_mask, + # num_heads, 
key_dim, value_dim, dropout, use_bias, output_shape) + return [ + # Small instances, default flags. + ([2, 3], [3, 4], False, False, 2, 2, 3, 0.0, True, None), # defaults + ([2, 3], [3, 4], True, False, 2, 2, 3, 0.0, True, None), # use key + ([2, 3], [3, 4], False, False, 2, 2, 3, 0.0, False, None), # no bias + ([2, 3], [3, 4], False, False, 2, 2, 3, 0.1, True, None), # dropout + ([2, 3], [3, 4], False, False, 2, 2, 3, 0.0, True, [3]), # output shape + ([2, 3], [3, 4], False, False, 1, 2, 3, 0.0, True, 3), # single head + ([2, 3], [3, 4], False, True, 2, 2, 3, 0.0, True, None), # attention mask + ] + + +def get_attention_layer_registries(): + attention_and_dense = layer_registry.LayerRegistry() + attention_and_dense.insert( + tf.keras.layers.MultiHeadAttention, + multi_head_attention.multi_head_attention_layer_computation, + ) + attention_and_dense.insert( + tf.keras.layers.Dense, + dense.dense_layer_computation, + ) + return { + 'attention_and_dense': attention_and_dense, + } + + +def get_multi_head_attention_example_inputs( + query_input_dims, + value_input_dims, + ragged_key=False, + ragged_value=False, + ragged_query=False, +): + """Generates example MultiHeadAttention concrete inputs for testing.""" + # Each batched input is a reshape of a `tf.range()` call. + batch_size = 2 + # Query input tensor. + query_size = tf.reduce_prod(query_input_dims) + query_tsr = tf.range(batch_size * query_size, dtype=tf.float32) / tf.cast( + query_size, tf.float32 + ) + query_batch = tf.reshape(query_tsr, [batch_size] + query_input_dims) + # Value input tensor. + value_size = tf.reduce_prod(value_input_dims) + value_tsr = ( + 2.0 + * tf.range(batch_size * value_size, dtype=tf.float32) + / tf.cast(value_size, tf.float32) + ) + value_batch = tf.reshape(value_tsr, [batch_size] + value_input_dims) + # Key input tensor (optional). + key_tsr = ( + 3.0 + * tf.range(batch_size * value_size, dtype=tf.float32) + / tf.cast(value_size, tf.float32) + ) + key_batch = tf.reshape(key_tsr, [batch_size] + value_input_dims) + # Attention mask input tensor (optional). + mask_size = tf.reduce_prod(query_input_dims[:-1]) * tf.reduce_prod( + value_input_dims[:-1] + ) + mask_input_dims = query_input_dims[:-1] + value_input_dims[:-1] + mask_tsr = tf.random.uniform([int(batch_size * mask_size)]) <= 0.5 + mask_batch = tf.reshape(mask_tsr, [batch_size] + mask_input_dims) + # Convert to ragged if needed. 
+ if ragged_query: + query_batch = tf.RaggedTensor.from_tensor(query_batch) + if ragged_value: + value_batch = tf.RaggedTensor.from_tensor(value_batch) + if ragged_key: + key_batch = tf.RaggedTensor.from_tensor(key_batch) + return query_batch, value_batch, key_batch, mask_batch + + +def get_multi_head_attention_model_inputs(input_dims): + """Creates MultiHeadAttention symbolic model input.""" + ( + query_input_dims, + value_input_dims, + key_input_dims, + mask_input_dims, + use_key, + use_mask, + ) = input_dims + query_input = tf.keras.Input(shape=query_input_dims) + value_input = tf.keras.Input(shape=value_input_dims) + key_input = tf.keras.Input(shape=key_input_dims) + mask_input = tf.keras.Input(shape=mask_input_dims) + + input_args = (query_input, value_input) + input_kwargs = {} + if use_key: + input_kwargs['key'] = key_input + if use_mask: + input_kwargs['attention_mask'] = mask_input + return input_args + (key_input, mask_input), input_args, input_kwargs + + +class CheckOutputs(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + use_key=[True, False], + use_query_as_kwarg=[True, False], + use_value_as_kwarg=[True, False], + use_attention_mask=[True, False], + return_scores=[True, False], + ragged_key=[True, False], + ragged_value=[True, False], + ragged_query=[True, False], + ) + def test_verify_consistent_outputs( + self, + use_key, + use_query_as_kwarg, + use_value_as_kwarg, + use_attention_mask, + return_scores, + ragged_key, + ragged_value, + ragged_query, + ): + num_heads = 2 + key_dim, value_dim = (2, 3) + query_input_dims = [2, 3] + value_input_dims = [3, 4] + query_batch, value_batch, key_batch, mask_batch = ( + get_multi_head_attention_example_inputs( + query_input_dims, + value_input_dims, + ragged_key=ragged_key, + ragged_value=ragged_value, + ragged_query=ragged_query, + ) + ) + layer_instance = tf.keras.layers.MultiHeadAttention( + num_heads, key_dim, value_dim + ) + + # Set up test inputs, branched on input order. + input_kwargs = { + 'key': key_batch if use_key else None, + 'attention_mask': mask_batch if use_attention_mask else None, + 'return_attention_scores': return_scores, + } + input_args = tuple() + if use_value_as_kwarg and use_query_as_kwarg: + input_kwargs['query'] = query_batch + input_kwargs['value'] = value_batch + elif use_value_as_kwarg: + input_kwargs['value'] = value_batch + input_args = (query_batch,) + elif use_query_as_kwarg: + # Invalid test case; cannot pass query kwarg after value. 
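+        # (The first positional argument of
+        # `tf.keras.layers.MultiHeadAttention.call()` is `query`, so `value`
+        # cannot be supplied positionally while `query` is given as a keyword
+        # argument.)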
+ return + else: + input_args = (query_batch, key_batch) + + dummy_tape = tf.GradientTape() + with dummy_tape: + _, computed_outputs, _ = ( + multi_head_attention.multi_head_attention_layer_computation( + layer_instance=layer_instance, + input_args=input_args, + input_kwargs=input_kwargs, + tape=dummy_tape, + ) + ) + true_outputs = layer_instance(*input_args, **input_kwargs) + self.assertAllClose(computed_outputs, true_outputs) + + +class GradNormTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.strategy = tf.distribute.get_strategy() + self.using_tpu = False + + @parameterized.product( + model_name=list(get_attention_model_generators().keys()), + layer_name=list(get_attention_layer_generators().keys()), + layer_registry_name=list(get_attention_layer_registries().keys()), + param_tuple=get_attention_parameter_tuples(), + num_microbatches=[None, 2], + is_eager=[True, False], + ) + def test_gradient_norms_on_various_models( + self, + model_name, + layer_name, + layer_registry_name, + param_tuple, + num_microbatches, + is_eager, + ): + # Parse inputs to generate test data. + ( + query_input_dims, + value_input_dims, + use_key, + use_attention_mask, + num_heads, + key_dim, + value_dim, + dropout, + use_bias, + output_shape, + ) = param_tuple + attention_generator_inputs = ( + num_heads, + key_dim, + value_dim, + dropout, + use_bias, + output_shape, + ) + query_batch, value_batch, key_batch, mask_batch = ( + get_multi_head_attention_example_inputs( + query_input_dims, value_input_dims + ) + ) + mask_input_dims = query_input_dims[:-1] + value_input_dims[:-1] + + # Make the layer generator via currying. + attention_generator = get_attention_layer_generators()[layer_name] + + def curried_generator(a, b): + del a, b + return attention_generator(*attention_generator_inputs) + + # Load shared assets to all devices. + with self.strategy.scope(): + model = common_test_utils.get_model_from_generator( + model_generator=get_attention_model_generators()[model_name], + layer_generator=curried_generator, + input_dims=( + query_input_dims, + value_input_dims, + value_input_dims, + mask_input_dims, + use_key, + use_attention_mask, + ), + output_dims=None, + is_eager=is_eager, + ) + + # Define the main testing ops. These may be later compiled to a Graph op. + def test_op(query_batch, value_batch, key_batch, mask_batch): + return common_test_utils.get_computed_and_true_norms_from_model( + model=model, + per_example_loss_fn=None, + num_microbatches=num_microbatches, + x_batch=(query_batch, value_batch, key_batch, mask_batch), + weight_batch=None, + registry=get_attention_layer_registries()[layer_registry_name], + partial=False, + ) + + # TPUs can only run `tf.function`-decorated functions. + if self.using_tpu: + test_op = tf.function(test_op, jit_compile=True, autograph=False) + + # TPUs use lower precision than CPUs, so we relax our criterion. + # E.g., one of the TPU runs generated the following results: + # + # computed_norm = 22.756414 + # true_norm = 23.338600 + # abs_diff = 0.58218575 + # rel_diff = 0.02494519 + # + # which is a reasonable level of error for computing gradient norms. + # Other trials also give an absolute (resp. relative) error of around + # 0.05 (resp. 0.0015). + rtol = 1e-1 if self.using_tpu else 1e-2 + atol = 1e-0 if self.using_tpu else 1e-1 + + # Set up the device ops and run the test. 
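+    # Under the default strategy, `strategy.run` executes the op inline; the
+    # TPU variant of this test (`multi_head_attention_tpu_test.py`) swaps in a
+    # TPU strategy during `setUp()`, in which case per-replica values are
+    # returned and unwrapped below.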
+ computed_norms, true_norms = self.strategy.run( + test_op, args=(query_batch, value_batch, key_batch, mask_batch) + ) + # TPUs return replica contexts, which must be unwrapped. + batch_size = tf.shape(query_batch)[0] + if self.using_tpu: + common_test_utils.assert_replica_values_are_close(self, computed_norms) + common_test_utils.assert_replica_values_are_close(self, true_norms) + computed_norms = computed_norms.values[0] + true_norms = true_norms.values[0] + expected_size = num_microbatches or batch_size + self.assertEqual(tf.shape(computed_norms)[0], expected_size) + self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_tpu_test.py new file mode 100644 index 0000000..9d9103d --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_tpu_test.py @@ -0,0 +1,30 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import multi_head_attention_test + + +class GradNormTpuTest(multi_head_attention_test.GradNormTest): + + def setUp(self): + super(multi_head_attention_test.GradNormTest, self).setUp() + self.strategy = ctu.create_tpu_strategy() + self.assertIn('TPU', self.strategy.extended.worker_devices[0]) + self.using_tpu = True + + +if __name__ == '__main__': + tf.test.main()