Implement and test a registry function for tf.keras.layers.MultiHeadAttention.
PiperOrigin-RevId: 584620638
parent 03db50ba94 · commit b19088f048
4 changed files with 645 additions and 0 deletions
@@ -176,3 +176,29 @@ py_test(
        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
    ],
)

py_library(
    name = "multi_head_attention",
    srcs = ["multi_head_attention.py"],
    srcs_version = "PY3",
    deps = [
        ":einsum_dense",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
    ],
)

py_test(
    name = "multi_head_attention_test",
    srcs = ["multi_head_attention_test.py"],
    python_version = "PY3",
    shard_count = 8,
    srcs_version = "PY3",
    deps = [
        ":dense",
        ":multi_head_attention",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
    ],
)
@@ -0,0 +1,214 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.MultiHeadAttention`."""

from collections.abc import Mapping, Sequence
from typing import Any, Optional

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense


def multi_head_attention_layer_computation(
    layer_instance: tf.keras.layers.MultiHeadAttention,
    input_args: Sequence[Any],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Registry function for `tf.keras.layers.MultiHeadAttention`.

  This function essentially applies the registry function for
  `tf.keras.layers.EinsumDense` three times. Some hints about the nature of
  the Einsum transforms are given below.

  -------------------
  ABOUT INPUT SHAPES
  -------------------
  For a given {query, key, value} input `I` of shape

    [Eq. A]  tf.shape(I) == [n, a[0], ..., a[k-1], b]

  where `n` is the batch size, the corresponding Einsum equation for its
  `EinsumDense` transform is given by:

    {n a[0] ... a[k-1] b},{b c d}->{n a[0] ... a[k-1] c d}

  where `c` corresponds to the number of attention heads
  (`layer_instance.num_heads`) and `d` corresponds to the size per head
  (`layer_instance.key_dim` or `layer_instance.value_dim`).

  The query, key, and value inputs are expected to have the same rank.

  ------------------
  ABOUT OUTPUT SHAPE
  ------------------
  Suppose the shape of the `query` input `Q` is given by [Eq. A] above with
  `I == Q`. Then, if `layer_instance.output_shape is None`, the output `O` of
  the layer satisfies `tf.shape(Q) == tf.shape(O)`. However, if
  `layer_instance.output_shape is not None`, then

    tf.shape(O) == [n, a[0], ..., a[k-1], *layer_instance.output_shape]

  Args:
    layer_instance: A `tf.keras.layers.MultiHeadAttention` instance.
    input_args: See `dense_layer_computation()`.
    input_kwargs: See `dense_layer_computation()`.
    tape: See `dense_layer_computation()`.
    num_microbatches: See `dense_layer_computation()`.

  Returns:
    See `dense_layer_computation()`.
  """
  # ----------------------
  # PREPROCESS THE INPUTS.
  # ----------------------
  query = (
      input_kwargs.get("query")
      if input_kwargs.get("query") is not None
      else input_args[0]
  )
  value = (
      input_kwargs.get("value")
      if input_kwargs.get("value") is not None
      else input_args[1]
  )
  key = input_kwargs.get("key")
  attention_mask = input_kwargs.get("attention_mask")
  return_attention_scores = input_kwargs.get("return_attention_scores")
  training = input_kwargs.get("training")
  use_causal_mask = input_kwargs.get("use_causal_mask")
  attention_mask = layer_instance._compute_attention_mask(  # pylint: disable=protected-access
      query,
      value,
      key=key,
      attention_mask=attention_mask,
      use_causal_mask=use_causal_mask,
  )
  if not layer_instance._built_from_signature:  # pylint: disable=protected-access
    layer_instance._build_from_signature(query=query, value=value, key=key)  # pylint: disable=protected-access
  if key is None:
    key = value

  query_lengths = 0
  query_is_ragged = isinstance(query, tf.RaggedTensor)
  if query_is_ragged:
    query_lengths = query.nested_row_lengths()
    query = query.to_tensor()

  key_is_ragged = isinstance(key, tf.RaggedTensor)
  value_is_ragged = isinstance(value, tf.RaggedTensor)
  if key_is_ragged and value_is_ragged:
    bounding_shape = tf.math.maximum(
        key.bounding_shape(), value.bounding_shape()
    )
    key = key.to_tensor(shape=bounding_shape)
    value = value.to_tensor(shape=bounding_shape)
  elif key_is_ragged:
    key = key.to_tensor(shape=tf.shape(value))
  elif value_is_ragged:
    value = value.to_tensor(shape=tf.shape(key))
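  # Editor's note on the ragged branches above: `to_tensor()` zero-pads each
  # ragged row to a dense bounding shape, e.g. (hypothetical values)
  #   rt = tf.ragged.constant([[1., 2., 3.], [4.]])
  #   rt.bounding_shape()  # ==> [2, 3]
  #   rt.to_tensor()       # ==> [[1., 2., 3.], [4., 0., 0.]]
  # `query_lengths` is saved so that the padded attention output can be
  # converted back to a ragged tensor during postprocessing.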
  # ------------------------------
  # APPLY THE FAST CLIPPING TRICK.
  # ------------------------------
  # trainable_op: W_q * QUERY
  query_base_vars, query, query_sqr_norm_fn = (
      einsum_dense.einsum_layer_computation(
          layer_instance._query_dense,  # pylint: disable=protected-access
          (query,),
          {},
          tape,
          num_microbatches,
      )
  )
  # trainable_op: W_k * KEY
  key_base_vars, key, key_sqr_norm_fn = einsum_dense.einsum_layer_computation(
      layer_instance._key_dense,  # pylint: disable=protected-access
      (key,),
      {},
      tape,
      num_microbatches,
  )
  # trainable_op: W_v * VALUE
  value_base_vars, value, value_sqr_norm_fn = (
      einsum_dense.einsum_layer_computation(
          layer_instance._value_dense,  # pylint: disable=protected-access
          (value,),
          {},
          tape,
          num_microbatches,
      )
  )
  # op: TEMP = ATTENTION(W_q * QUERY, W_k * KEY, W_v * VALUE)
  temp_output, attention_scores = layer_instance._compute_attention(  # pylint: disable=protected-access
      query,
      key,
      value,
      attention_mask,
      training,
  )
  # trainable_op: W_o * TEMP
  (
      attention_output_base_vars,
      attention_output,
      attention_output_sqr_norm_fn,
  ) = einsum_dense.einsum_layer_computation(
      layer_instance._output_dense,  # pylint: disable=protected-access
      (temp_output,),
      {},
      tape,
      num_microbatches,
  )
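  # Editor's note: the four `einsum_layer_computation()` calls above cover
  # all of this layer's trainable ops (W_q, W_k, W_v, W_o);
  # `_compute_attention()` itself has no trainable variables. Since the four
  # weight sets are disjoint, the layer's per-example gradient square norm is
  # the sum of the four per-op square norms, computed by `sqr_norm_fn` below.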
  # ------------------------
  # POSTPROCESS THE OUTPUTS.
  # ------------------------
  # Get registry output tensors ready.
  if query_is_ragged:
    attention_output = tf.RaggedTensor.from_tensor(
        attention_output, query_lengths
    )
  outputs = attention_output
  if return_attention_scores:
    outputs = (attention_output, attention_scores)
  base_vars = [
      query_base_vars,
      key_base_vars,
      value_base_vars,
      attention_output_base_vars,
  ]

  # The square norm function should just aggregate the squared norms
  # corresponding to each trainable op.
  def sqr_norm_fn(grad_list):
    if len(grad_list) != 4:
      raise ValueError(
          "Expected a container of 4 gradients for the `MultiHeadAttention` "
          "square norm function's input. Instead, received a container of "
          "size " + str(len(grad_list))
      )
    combined_sqr_norms = tf.stack(
        [
            query_sqr_norm_fn(grad_list[0]),
            key_sqr_norm_fn(grad_list[1]),
            value_sqr_norm_fn(grad_list[2]),
            attention_output_sqr_norm_fn(grad_list[3]),
        ],
        axis=1,
    )
    return tf.reduce_sum(combined_sqr_norms, axis=1)

  return base_vars, outputs, sqr_norm_fn
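A minimal standalone sketch of the aggregation performed by `sqr_norm_fn`
above, with made-up norm values; only the `tf.stack`/`tf.reduce_sum` pattern
is taken from the commit:

    import tensorflow as tf

    # Hypothetical per-example squared norms for the four trainable ops
    # (W_q, W_k, W_v, W_o) over a batch of 2 examples.
    per_op_sqr_norms = [
        tf.constant([1.0, 2.0]),  # query op
        tf.constant([0.5, 0.5]),  # key op
        tf.constant([2.0, 1.0]),  # value op
        tf.constant([0.5, 1.5]),  # output op
    ]
    # Stack to shape [batch_size, num_ops], then sum over ops, exactly as in
    # `sqr_norm_fn`: one squared norm per example.
    combined = tf.stack(per_op_sqr_norms, axis=1)
    print(tf.reduce_sum(combined, axis=1))  # tf.Tensor([4. 5.], shape=(2,))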
@@ -0,0 +1,375 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import multi_head_attention


def get_attention_layer_generators():
  def basic_attention_layer(
      num_heads,
      key_dim,
      value_dim,
      dropout,
      use_bias,
      output_shape,
  ):
    return tf.keras.layers.MultiHeadAttention(
        num_heads,
        key_dim,
        value_dim=value_dim,
        dropout=dropout,
        use_bias=use_bias,
        output_shape=output_shape,
    )

  return {
      'basic_attention_layer': basic_attention_layer,
  }


def make_one_layer_attention_model(layer_generator, input_dims, output_dims):
  """Creates a 1-layer MultiHeadAttention model."""
  inputs, input_args, input_kwargs = get_multi_head_attention_model_inputs(
      input_dims
  )
  layer1 = layer_generator(input_dims, output_dims)
  del output_dims
  temp1 = layer1(*input_args, **input_kwargs)
  outputs = common_test_utils.reshape_and_sum(temp1)
  return tf.keras.Model(inputs=inputs, outputs=outputs)


def make_two_layer_attention_model(layer_generator, input_dims, output_dims):
  """Creates a 2-layer MultiHeadAttention model."""
  inputs, input_args, input_kwargs = get_multi_head_attention_model_inputs(
      input_dims
  )
  layer1 = layer_generator(input_dims, output_dims)
  temp1 = layer1(*input_args, **input_kwargs)
  temp2 = tf.keras.layers.Dense(1)(temp1)
  outputs = common_test_utils.reshape_and_sum(temp2)
  return tf.keras.Model(inputs=inputs, outputs=outputs)


def get_attention_model_generators():
  return {
      'seq1_mha': make_one_layer_attention_model,
      'seq2_mha': make_two_layer_attention_model,
  }


def get_attention_parameter_tuples():
  # (query_input_dims, value_input_dims, use_key, use_attention_mask,
  #  num_heads, key_dim, value_dim, dropout, use_bias, output_shape)
  return [
      # Small instances, default flags.
      ([2, 3], [3, 4], False, False, 2, 2, 3, 0.0, True, None),  # defaults
      ([2, 3], [3, 4], True, False, 2, 2, 3, 0.0, True, None),  # use key
      ([2, 3], [3, 4], False, False, 2, 2, 3, 0.0, False, None),  # no bias
      ([2, 3], [3, 4], False, False, 2, 2, 3, 0.1, True, None),  # dropout
      ([2, 3], [3, 4], False, False, 2, 2, 3, 0.0, True, [3]),  # output shape
      ([2, 3], [3, 4], False, False, 1, 2, 3, 0.0, True, 3),  # single head
      ([2, 3], [3, 4], False, True, 2, 2, 3, 0.0, True, None),  # attention mask
  ]


def get_attention_layer_registries():
  attention_and_dense = layer_registry.LayerRegistry()
  attention_and_dense.insert(
      tf.keras.layers.MultiHeadAttention,
      multi_head_attention.multi_head_attention_layer_computation,
  )
  attention_and_dense.insert(
      tf.keras.layers.Dense,
      dense.dense_layer_computation,
  )
  return {
      'attention_and_dense': attention_and_dense,
  }
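# Editor's sketch (hedged) of how a registry built above is consumed
# downstream; the `clip_grads` entry point named here is assumed from this
# target's `deps`, not shown in this commit:
#
#   registry = get_attention_layer_registries()['attention_and_dense']
#   # clip_grads.compute_gradient_norms(model, registry, x_batch, ...) then
#   # dispatches on each layer's type to compute per-example gradient norms
#   # without materializing per-example gradients.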


def get_multi_head_attention_example_inputs(
    query_input_dims,
    value_input_dims,
    ragged_key=False,
    ragged_value=False,
    ragged_query=False,
):
  """Generates example MultiHeadAttention concrete inputs for testing."""
  # Each batched input is a reshape of a `tf.range()` call.
  batch_size = 2
  # Query input tensor.
  query_size = tf.reduce_prod(query_input_dims)
  query_tsr = tf.range(batch_size * query_size, dtype=tf.float32) / tf.cast(
      query_size, tf.float32
  )
  query_batch = tf.reshape(query_tsr, [batch_size] + query_input_dims)
  # Value input tensor.
  value_size = tf.reduce_prod(value_input_dims)
  value_tsr = (
      2.0
      * tf.range(batch_size * value_size, dtype=tf.float32)
      / tf.cast(value_size, tf.float32)
  )
  value_batch = tf.reshape(value_tsr, [batch_size] + value_input_dims)
  # Key input tensor (optional).
  key_tsr = (
      3.0
      * tf.range(batch_size * value_size, dtype=tf.float32)
      / tf.cast(value_size, tf.float32)
  )
  key_batch = tf.reshape(key_tsr, [batch_size] + value_input_dims)
  # Attention mask input tensor (optional).
  mask_size = tf.reduce_prod(query_input_dims[:-1]) * tf.reduce_prod(
      value_input_dims[:-1]
  )
  mask_input_dims = query_input_dims[:-1] + value_input_dims[:-1]
  mask_tsr = tf.random.uniform([int(batch_size * mask_size)]) <= 0.5
  mask_batch = tf.reshape(mask_tsr, [batch_size] + mask_input_dims)
  # Convert to ragged if needed.
  if ragged_query:
    query_batch = tf.RaggedTensor.from_tensor(query_batch)
  if ragged_value:
    value_batch = tf.RaggedTensor.from_tensor(value_batch)
  if ragged_key:
    key_batch = tf.RaggedTensor.from_tensor(key_batch)
  return query_batch, value_batch, key_batch, mask_batch
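# Editor's illustration of the helper above (shapes follow directly from the
# code): with query_input_dims=[2, 3] and value_input_dims=[3, 4],
#   tf.shape(query_batch) == [2, 2, 3]
#   tf.shape(value_batch) == [2, 3, 4]
#   tf.shape(key_batch)   == [2, 3, 4]
#   tf.shape(mask_batch)  == [2, 2, 3]  # query dims[:-1] + value dims[:-1]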


def get_multi_head_attention_model_inputs(input_dims):
  """Creates MultiHeadAttention symbolic model inputs."""
  (
      query_input_dims,
      value_input_dims,
      key_input_dims,
      mask_input_dims,
      use_key,
      use_mask,
  ) = input_dims
  query_input = tf.keras.Input(shape=query_input_dims)
  value_input = tf.keras.Input(shape=value_input_dims)
  key_input = tf.keras.Input(shape=key_input_dims)
  mask_input = tf.keras.Input(shape=mask_input_dims)

  input_args = (query_input, value_input)
  input_kwargs = {}
  if use_key:
    input_kwargs['key'] = key_input
  if use_mask:
    input_kwargs['attention_mask'] = mask_input
  return input_args + (key_input, mask_input), input_args, input_kwargs
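# Editor's illustration: `input_dims` above is the 6-tuple built by
# `GradNormTest` below, e.g. (hypothetical values)
#   ([2, 3], [3, 4], [3, 4], [2, 3], True, False)
# i.e. (query_input_dims, value_input_dims, key_input_dims, mask_input_dims,
# use_key, use_mask).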


class CheckOutputs(tf.test.TestCase, parameterized.TestCase):

  @parameterized.product(
      use_key=[True, False],
      use_query_as_kwarg=[True, False],
      use_value_as_kwarg=[True, False],
      use_attention_mask=[True, False],
      return_scores=[True, False],
      ragged_key=[True, False],
      ragged_value=[True, False],
      ragged_query=[True, False],
  )
  def test_verify_consistent_outputs(
      self,
      use_key,
      use_query_as_kwarg,
      use_value_as_kwarg,
      use_attention_mask,
      return_scores,
      ragged_key,
      ragged_value,
      ragged_query,
  ):
    num_heads = 2
    key_dim, value_dim = (2, 3)
    query_input_dims = [2, 3]
    value_input_dims = [3, 4]
    query_batch, value_batch, key_batch, mask_batch = (
        get_multi_head_attention_example_inputs(
            query_input_dims,
            value_input_dims,
            ragged_key=ragged_key,
            ragged_value=ragged_value,
            ragged_query=ragged_query,
        )
    )
    layer_instance = tf.keras.layers.MultiHeadAttention(
        num_heads, key_dim, value_dim
    )

    # Set up test inputs, branched on input order.
    input_kwargs = {
        'key': key_batch if use_key else None,
        'attention_mask': mask_batch if use_attention_mask else None,
        'return_attention_scores': return_scores,
    }
    input_args = tuple()
    if use_value_as_kwarg and use_query_as_kwarg:
      input_kwargs['query'] = query_batch
      input_kwargs['value'] = value_batch
    elif use_value_as_kwarg:
      input_kwargs['value'] = value_batch
      input_args = (query_batch,)
    elif use_query_as_kwarg:
      # Invalid test case: `value` is positional here, so `query` cannot be
      # passed as a kwarg.
      return
    else:
      input_args = (query_batch, value_batch)
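    # Editor's note: `MultiHeadAttention.call(query, value, key=None, ...)`
    # takes `query` and then `value` positionally, so the branches above
    # cover every valid argument ordering; e.g. `layer(q, v)` and
    # `layer(query=q, value=v)` are equivalent calls.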
    dummy_tape = tf.GradientTape()
    with dummy_tape:
      _, computed_outputs, _ = (
          multi_head_attention.multi_head_attention_layer_computation(
              layer_instance=layer_instance,
              input_args=input_args,
              input_kwargs=input_kwargs,
              tape=dummy_tape,
          )
      )
    true_outputs = layer_instance(*input_args, **input_kwargs)
    self.assertAllClose(computed_outputs, true_outputs)


class GradNormTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.strategy = tf.distribute.get_strategy()
    self.using_tpu = False

  @parameterized.product(
      model_name=list(get_attention_model_generators().keys()),
      layer_name=list(get_attention_layer_generators().keys()),
      layer_registry_name=list(get_attention_layer_registries().keys()),
      param_tuple=get_attention_parameter_tuples(),
      num_microbatches=[None, 2],
      is_eager=[True, False],
  )
  def test_gradient_norms_on_various_models(
      self,
      model_name,
      layer_name,
      layer_registry_name,
      param_tuple,
      num_microbatches,
      is_eager,
  ):
    # Parse inputs to generate test data.
    (
        query_input_dims,
        value_input_dims,
        use_key,
        use_attention_mask,
        num_heads,
        key_dim,
        value_dim,
        dropout,
        use_bias,
        output_shape,
    ) = param_tuple
    attention_generator_inputs = (
        num_heads,
        key_dim,
        value_dim,
        dropout,
        use_bias,
        output_shape,
    )
    query_batch, value_batch, key_batch, mask_batch = (
        get_multi_head_attention_example_inputs(
            query_input_dims, value_input_dims
        )
    )
    mask_input_dims = query_input_dims[:-1] + value_input_dims[:-1]

    # Make the layer generator via currying.
    attention_generator = get_attention_layer_generators()[layer_name]

    def curried_generator(a, b):
      del a, b
      return attention_generator(*attention_generator_inputs)

    # Load shared assets onto all devices.
    with self.strategy.scope():
      model = common_test_utils.get_model_from_generator(
          model_generator=get_attention_model_generators()[model_name],
          layer_generator=curried_generator,
          input_dims=(
              query_input_dims,
              value_input_dims,
              value_input_dims,
              mask_input_dims,
              use_key,
              use_attention_mask,
          ),
          output_dims=None,
          is_eager=is_eager,
      )

    # Define the main testing ops. These may later be compiled to a Graph op.
    def test_op(query_batch, value_batch, key_batch, mask_batch):
      return common_test_utils.get_computed_and_true_norms_from_model(
          model=model,
          per_example_loss_fn=None,
          num_microbatches=num_microbatches,
          x_batch=(query_batch, value_batch, key_batch, mask_batch),
          weight_batch=None,
          registry=get_attention_layer_registries()[layer_registry_name],
          partial=False,
      )

    # TPUs can only run `tf.function`-decorated functions.
    if self.using_tpu:
      test_op = tf.function(test_op, jit_compile=True, autograph=False)

    # TPUs use lower precision than CPUs, so we relax our criterion.
    # E.g., one of the TPU runs generated the following results:
    #
    #   computed_norm = 22.756414
    #   true_norm     = 23.338600
    #   abs_diff      = 0.58218575
    #   rel_diff      = 0.02494519
    #
    # which is a reasonable level of error for computing gradient norms.
    # Other trials give an absolute (resp. relative) error of around
    # 0.05 (resp. 0.0015).
    rtol = 1e-1 if self.using_tpu else 1e-2
    atol = 1e-0 if self.using_tpu else 1e-1

    # Set up the device ops and run the test.
    computed_norms, true_norms = self.strategy.run(
        test_op, args=(query_batch, value_batch, key_batch, mask_batch)
    )
    # TPUs return replica contexts, which must be unwrapped.
    batch_size = tf.shape(query_batch)[0]
    if self.using_tpu:
      common_test_utils.assert_replica_values_are_close(self, computed_norms)
      common_test_utils.assert_replica_values_are_close(self, true_norms)
      computed_norms = computed_norms.values[0]
      true_norms = true_norms.values[0]
    expected_size = num_microbatches or batch_size
    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
    self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,30 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import multi_head_attention_test


class GradNormTpuTest(multi_head_attention_test.GradNormTest):

  def setUp(self):
    # Skip `GradNormTest.setUp()` so that its default strategy is replaced
    # with a TPU strategy.
    super(multi_head_attention_test.GradNormTest, self).setUp()
    self.strategy = ctu.create_tpu_strategy()
    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
    self.using_tpu = True


if __name__ == '__main__':
  tf.test.main()