Implement and test a registry function for tfm.nlp.layers.OnDeviceEmbedding.

This CL also moves the `sqr_norm_fn` logic shared by `tf.keras.layers.Embedding` and `tfm.nlp.layers.OnDeviceEmbedding` into a new registry function utility file.

PiperOrigin-RevId: 564481407
parent 113b27be43
commit bcc0d4927e

10 changed files with 443 additions and 122 deletions
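
For reference, the new registry function is wired up in the same way as the existing `tf.keras.layers.Embedding` one. A minimal sketch, mirroring the registry setup in the new test file below (all names come from this CL):

    import tensorflow as tf
    import tensorflow_models as tfm

    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
    from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
    from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding

    # Register fast-clipping computations for both layer types.
    registry = layer_registry.LayerRegistry()
    registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
    registry.insert(
        tfm.nlp.layers.OnDeviceEmbedding,
        nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
    )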
@@ -37,3 +37,4 @@ tensorflow-datasets~=4.5
 tensorflow-estimator~=2.4
 tensorflow-probability~=0.20.0
 tensorflow~=2.4
+tf-models-official~=2.13
setup.py
@@ -44,6 +44,7 @@ setup(
         'tensorflow-estimator~=2.4',
         'tensorflow-probability~=0.20.0',
         'tensorflow~=2.4',
+        'tf-models-official~=2.13',
     ],
     packages=find_packages(),
 )
@@ -294,15 +294,13 @@ def make_bow_model(
     output_dims: List[int],
 ) -> tf.keras.Model:
   """Creates a simple embedding bow model."""
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dims argument, which is the number of ids
   # in each example.
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
+  emb_layer = layer_generator(input_dims, output_dims)
   feature_embs = emb_layer(inputs)
   # Embeddings add one extra dimension to their inputs, which combined with the
   # batch dimension at dimension 0, equals two additional dimensions compared
@@ -321,18 +319,13 @@ def make_dense_bow_model(
     output_dims: List[int],
 ) -> tf.keras.Model:
   """Creates an embedding bow model with a `Dense` layer."""
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dims argument, which is the number of ids
   # in each example.
-  cardinality = 10
+  emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(
-      input_dim=cardinality, output_dim=output_dim
-  )
   feature_embs = emb_layer(inputs)
   # Embeddings add one extra dimension to their inputs, which combined with the
   # batch dimension at dimension 0, equals two additional dimensions compared
@@ -353,18 +346,14 @@ def make_weighted_bow_model(
 ) -> tf.keras.Model:
   """Creates a weighted embedding bow model."""
   # NOTE: This model only accepts dense input tensors.
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dims argument, which is the number of ids
   # in each example.
-  cardinality = 10
+  emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
   output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(
-      input_dim=cardinality, output_dim=output_dim
-  )
   feature_embs = emb_layer(inputs)
   # Use deterministic weights to avoid seeding issues on TPUs.
   feature_shape = input_dims + [output_dim]
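
The three builders above now take their embedding layer from `layer_generator` instead of hardcoding `tf.keras.layers.Embedding`. A hedged sketch of driving one of them, assuming the keyword signature shown in the hunks (the dims below are illustrative):

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils

    def embed_layer_generator(_, output_dims):
      # Vocabulary size of 10, matching the cardinality the builders used to
      # hardcode.
      return tf.keras.layers.Embedding(10, *output_dims)

    model = common_test_utils.make_bow_model(
        layer_generator=embed_layer_generator,
        input_dims=[3],
        output_dims=[2],
    )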
@@ -2,6 +2,13 @@ package(
     default_visibility = ["//visibility:public"],
 )

+py_library(
+    name = "registry_function_utils",
+    srcs = ["registry_function_utils.py"],
+    srcs_version = "PY3",
+    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+)
+
 py_library(
     name = "einsum_utils",
     srcs = ["einsum_utils.py"],
@@ -45,7 +52,10 @@ py_library(
     name = "embedding",
     srcs = ["embedding.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+    deps = [
+        ":registry_function_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
 )

 py_test(
@@ -63,6 +73,31 @@ py_test(
     ],
 )

+py_library(
+    name = "nlp_on_device_embedding",
+    srcs = ["nlp_on_device_embedding.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":registry_function_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
+)
+
+py_test(
+    name = "nlp_on_device_embedding_test",
+    srcs = ["nlp_on_device_embedding_test.py"],
+    python_version = "PY3",
+    shard_count = 6,
+    srcs_version = "PY3",
+    deps = [
+        ":dense",
+        ":nlp_on_device_embedding",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+    ],
+)
+
 py_library(
     name = "layer_normalization",
     srcs = ["layer_normalization.py"],
@@ -16,6 +16,7 @@
 from typing import Any, Mapping, Tuple, Union
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils


 def embedding_layer_computation(
@@ -71,110 +72,11 @@ def embedding_layer_computation(
   tape.watch(base_vars)
   outputs = tf.nn.embedding_lookup(base_vars, input_ids)

-  def sqr_norm_fn(base_vars_grads):
-    """Fast square norm function for Keras embedding layers.
-
-    Args:
-      base_vars_grads: A list of batched embedding gradients.
-
-    Returns:
-      A 1D `tf.Tensor` of squared gradient norms.
-
-    NOTE: to help understand the code, we document in the function body what
-    the expected intermediate variables are for the below running example:
-
-      input_ids = [[1, 1, 2], [0], [2, 0]]
-      base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
-      base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
-
-    For ease of reference, we also list these variables below:
-
-      row_indices = [[0], [0], [0], [1], [2], [2]]
-      slice_indices = [[1], [1], [2], [0], [2], [0]]
-      paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-      unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-      new_index_positions = [0, 0, 1, 2, 3, 4]
-      num_unique_paired_indices = 5
-      unique_batch_ids = [0, 0, 1, 2, 2]
-      summed_gradients
-        = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
-        = [[0.4], [0.3], [0.1], [0.3], [0.1]]
-      sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
-    """
-    # We first get a 1D tensor of the row indices.
-    nrows = tf.shape(input_ids)[0]
-    if isinstance(input_ids, tf.RaggedTensor):
-      row_indices = tf.expand_dims(
-          input_ids.merge_dims(1, -1).value_rowids(), axis=-1
-      )
-    elif isinstance(input_ids, tf.Tensor):
-      ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
-      repeats = tf.repeat(ncols, nrows)
-      row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
-    else:
-      raise NotImplementedError(
-          "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
-      )
-    row_indices = tf.cast(row_indices, tf.int64)
-    if num_microbatches is not None:
-      microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
-      nrows = num_microbatches
-      row_indices = tf.cast(
-          tf.math.floordiv(row_indices, microbatch_size), tf.int64
-      )
-    # NOTE: expected values for the running example above are
-    #   row_indices = [[0], [0], [0], [1], [2], [2]]
-
-    # Now, sum-reduce the `IndexedSlices` that is the result of a
-    # `tape.gradient()` call. The sum is reduced by the repeated embedding
-    # indices and batch index. It is adapted from the logic in:
-    #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
-    if not isinstance(base_vars_grads, tf.IndexedSlices):
-      raise NotImplementedError(
-          "Cannot parse embedding gradients of type: %s"
-          % base_vars_grads.__class__.__name__
-      )
-    slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
-    paired_indices = tf.concat(
-        [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
-        axis=1,
-    )
-    (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
-        x=paired_indices, axis=[0]
-    )
-    # NOTE: expected values for the running example above are
-    #   slice_indices = [[1], [1], [2], [0], [2], [0]]
-    #   paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-    #   unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-    #   new_index_positions = [0, 0, 1, 2, 3, 4]
-
-    # Next, sum according to the new positions and compute the squared
-    # gradient norms. Oddly enough, not sorting these indices will break
-    # tensor shape inference logic on TPUs.
-    num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
-    unique_batch_ids = unique_paired_indices[:, 0]
-    summed_gradients = tf.math.unsorted_segment_sum(
-        base_vars_grads.values,
-        new_index_positions,
-        num_unique_paired_indices,
-    )
-    sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
-    # NOTE: expected values for the running example above are
-    #   num_unique_paired_indices = 5
-    #   unique_batch_ids = [0, 0, 1, 2, 2]
-    #   summed_gradients
-    #     = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
-    #     = [[0.4], [0.3], [0.1], [0.3], [0.1]]
-    #   sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
-
-    # Use a scatter-add strategy to ensure TPU compatibility.
-    result = tf.zeros([nrows])
-    return tf.tensor_scatter_nd_add(
-        result,
-        tf.expand_dims(unique_batch_ids, axis=-1),
-        sqr_gradient_sum,
-    )
-    # NOTE: the expected output for the running example is
-    #   [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
+  def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
+    return registry_function_utils.embedding_sqr_norm_fn(
+        base_vars_grads.values,
+        input_ids,
+        num_microbatches,
+    )

   return base_vars, outputs, sqr_norm_fn
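
For intuition on the delegation above: the gradient of a scalar loss with respect to an embedding table arrives as a `tf.IndexedSlices` with one row per lookup, which is why only its `values` and the original `input_ids` need to be forwarded to the shared utility. A standalone sketch of that TensorFlow behavior (illustrative values, not from this CL):

    import tensorflow as tf

    table = tf.Variable(tf.ones([10, 2]))  # a vocab-10, width-2 embedding table
    input_ids = tf.constant([[1, 1, 2], [1, 2, 0]])

    with tf.GradientTape() as tape:
      outputs = tf.nn.embedding_lookup(table, input_ids)
      loss = tf.reduce_sum(outputs)

    grads = tape.gradient(loss, table)
    print(type(grads).__name__)   # IndexedSlices
    print(grads.indices.numpy())  # one entry per lookup: [1 1 2 1 2 0]
    print(grads.values.shape)     # (6, 2), aligned with the flattened ids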
@@ -115,9 +115,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):

     # Load shared assets to all devices.
     with self.strategy.scope():
+
+      def embed_layer_generator(_, output_dims):
+        return tf.keras.layers.Embedding(10, *output_dims)
+
       model = common_test_utils.get_model_from_generator(
           model_generator=get_embedding_model_generators()[model_name],
-          layer_generator=None,
+          layer_generator=embed_layer_generator,
           input_dims=embed_indices.shape[1:],
           output_dims=[output_dim],
           is_eager=is_eager,
@@ -0,0 +1,70 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""

from typing import Any, Dict, Optional, Tuple
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils


def nlp_on_device_embedding_layer_computation(
    layer_instance: tf.keras.layers.Layer,
    input_args: Tuple[Any, ...],
    input_kwargs: Dict[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Registry function for `tfm.nlp.layers.OnDeviceEmbedding`.

  Args:
    layer_instance: A `tfm.nlp.layers.OnDeviceEmbedding` instance.
    input_args: See `dense_layer_computation()` in `dense.py`.
    input_kwargs: See `dense_layer_computation()` in `dense.py`.
    tape: See `dense_layer_computation()` in `dense.py`.
    num_microbatches: See `dense_layer_computation()` in `dense.py`.

  Returns:
    See `dense_layer_computation()` in `dense.py`.
  """
  if input_kwargs:
    raise ValueError("Embedding layer calls should not receive kwargs.")
  del input_kwargs
  if len(input_args) != 1:
    raise ValueError("Only layer inputs of length 1 are permitted.")
  if hasattr(layer_instance, "_use_one_hot"):
    if layer_instance._use_one_hot:  # pylint: disable=protected-access
      raise NotImplementedError(
          "The embedding feature '_use_one_hot' is not supported."
      )
  # NOTE: Since the implementation of `tfm.nlp.layers.OnDeviceEmbedding` uses
  # `.set_shape()`, we can assume that inputs are not ragged.
  input_ids = tf.cast(*input_args, tf.int32)
  if len(layer_instance.trainable_variables) != 1:
    raise ValueError(
        "Only layer instances with a single set of trainable variables "
        "are permitted."
    )
  base_vars = layer_instance.trainable_variables[0]
  tape.watch(base_vars)
  outputs = layer_instance(input_ids)

  def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
    return registry_function_utils.embedding_sqr_norm_fn(
        base_vars_grads.values,
        input_ids,
        num_microbatches,
    )

  return base_vars, outputs, sqr_norm_fn
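
A hedged usage sketch for the new registry function (not part of the CL; the clipping machinery normally makes this call). The layer is called once up front so its embedding table exists, since the function requires exactly one set of trainable variables:

    import tensorflow as tf
    import tensorflow_models as tfm

    from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding

    layer = tfm.nlp.layers.OnDeviceEmbedding(vocab_size=10, embedding_width=2)
    input_ids = tf.constant([[1, 1, 2], [1, 2, 0]])
    _ = layer(input_ids)  # build the layer so its embedding table exists

    with tf.GradientTape() as tape:
      base_vars, outputs, sqr_norm_fn = (
          nlp_on_device_embedding.nlp_on_device_embedding_layer_computation(
              layer, (input_ids,), {}, tape
          )
      )
      loss = tf.reduce_sum(outputs)  # stand-in for a summed per-example loss

    grads = tape.gradient(loss, base_vars)  # a `tf.IndexedSlices`
    per_example_sqr_norms = sqr_norm_fn(grads)  # shape [batch_size] = [2]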
@@ -0,0 +1,159 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding


# ==============================================================================
# Helper functions.
# ==============================================================================
def get_nlp_on_device_embedding_model_generators():
  return {
      'bow1': common_test_utils.make_bow_model,
      'bow2': common_test_utils.make_dense_bow_model,
      'weighted_bow1': common_test_utils.make_weighted_bow_model,
  }


def get_nlp_on_device_embedding_inputs():
  """Generates input data."""
  return [
      # 2D inputs.
      [[0, 1]],
      [[0, 1], [1, 1], [0, 0]],
      # 3D inputs.
      [[[0, 1]]],
      [[[0, 1]], [[1, 1]], [[0, 0]]],
  ]


def get_nlp_on_device_embedding_layer_registries():
  dbl_registry = layer_registry.LayerRegistry()
  dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
  dbl_registry.insert(
      tfm.nlp.layers.OnDeviceEmbedding,
      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
  )
  return {
      'embed_and_dense': dbl_registry,
  }


# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.strategy = tf.distribute.get_strategy()

  # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
  # supports them for embeddings.
  @parameterized.product(
      input_data=get_nlp_on_device_embedding_inputs(),
      scale_factor=[None, 0.5, 1.0],
      model_name=list(
          get_nlp_on_device_embedding_model_generators().keys()
      ),
      output_dim=[2],
      layer_registry_name=list(
          get_nlp_on_device_embedding_layer_registries().keys()
      ),
      num_microbatches=[None, 2],
      is_eager=[True, False],
      partial=[True, False],
  )
  def test_gradient_norms_on_various_models(
      self,
      input_data,
      scale_factor,
      model_name,
      output_dim,
      layer_registry_name,
      num_microbatches,
      is_eager,
      partial,
  ):
    # Parse inputs to generate test data.
    embed_indices = tf.convert_to_tensor(input_data, dtype_hint=tf.int32)

    # The following are invalid test combinations and, hence, are skipped.
    batch_size = embed_indices.shape[0]
    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
    if num_microbatches is not None and batch_size % num_microbatches != 0:
      return

    # Load shared assets to all devices.
    with self.strategy.scope():

      def embed_layer_generator(_, output_dims):
        return tfm.nlp.layers.OnDeviceEmbedding(
            10,
            *output_dims,
            scale_factor=scale_factor
        )

      model = common_test_utils.get_model_from_generator(
          model_generator=(
              get_nlp_on_device_embedding_model_generators()[model_name]
          ),
          layer_generator=embed_layer_generator,
          input_dims=embed_indices.shape[1:],
          output_dims=[output_dim],
          is_eager=is_eager,
      )

    # Define the main testing ops. These may be later compiled to a Graph op.
    def test_op(x_batch):
      return common_test_utils.get_computed_and_true_norms_from_model(
          model=model,
          per_example_loss_fn=None,
          num_microbatches=num_microbatches,
          x_batch=x_batch,
          registry=(
              get_nlp_on_device_embedding_layer_registries()[
                  layer_registry_name
              ]
          ),
          partial=partial,
      )

    # TPUs can only run `tf.function`-decorated functions.
    if using_tpu:
      test_op = tf.function(test_op, autograph=False)

    # Set up the device ops and run the test.
    computed_norms, true_norms = self.strategy.run(
        test_op, args=(embed_indices,)
    )
    # TPUs return replica contexts, which must be unwrapped.
    if using_tpu:
      common_test_utils.assert_replica_values_are_close(self, computed_norms)
      common_test_utils.assert_replica_values_are_close(self, true_norms)
      computed_norms = computed_norms.values[0]
      true_norms = true_norms.values[0]
    expected_size = num_microbatches or batch_size
    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
    self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding_test


class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):

  def setUp(self):
    super().setUp()
    self.strategy = common_test_utils.create_tpu_strategy()
    self.assertIn('TPU', self.strategy.extended.worker_devices[0])


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,131 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Useful utility functions for implementing registry functions."""

from typing import Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def embedding_sqr_norm_fn(
    grad_values: tf.Tensor,
    input_ids: Union[tf.Tensor, tf.RaggedTensor],
    num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Fast square norm function for general embedding layers.

  Args:
    grad_values: Batched embedding gradient values. These come from the
      'values' field of the `IndexedSlices` object representing the embedding
      gradient.
    input_ids: A `tf.Tensor` of queried embedding ids. The values in this
      input are the same as the `indices` field of the `IndexedSlices` object
      representing the embedding gradient.
    num_microbatches: An optional number of microbatches.

  Returns:
    A 1D `tf.Tensor` of squared gradient norms.

  NOTE: to help understand the code, we document in the function body what
  the expected intermediate variables are for the below running example:

    input_ids = [[1, 1, 2], [0], [2, 0]]
    grad_values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]

  For ease of reference, we also list these variables below:

    row_indices = [[0], [0], [0], [1], [2], [2]]
    flattened_indices = [[1], [1], [2], [0], [2], [0]]
    paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
    unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
    new_index_positions = [0, 0, 1, 2, 3, 4]
    num_unique_paired_indices = 5
    unique_batch_ids = [0, 0, 1, 2, 2]
    summed_gradients
      = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
      = [[0.4], [0.3], [0.1], [0.3], [0.1]]
    sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
  """
  # We first get a 1D tensor of the row indices.
  nrows = tf.shape(input_ids)[0]
  if isinstance(input_ids, tf.RaggedTensor):
    row_indices = tf.expand_dims(
        input_ids.merge_dims(1, -1).value_rowids(), axis=-1
    )
  elif isinstance(input_ids, tf.Tensor):
    ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
    repeats = tf.repeat(ncols, nrows)
    row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
  else:
    raise NotImplementedError(
        "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
    )
  row_indices = tf.cast(row_indices, tf.int64)
  if num_microbatches is not None:
    microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
    nrows = num_microbatches
    row_indices = tf.cast(
        tf.math.floordiv(row_indices, microbatch_size), tf.int64
    )
  # NOTE: expected values for the running example above are
  #   row_indices = [[0], [0], [0], [1], [2], [2]]

  # Now, sum-reduce the `IndexedSlices` that is the result of a
  # `tape.gradient()` call. The sum is reduced by the repeated embedding
  # indices and batch index. It is adapted from the logic in:
  #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
  flattened_indices = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1)
  paired_indices = tf.concat(
      [
          tf.cast(row_indices, tf.int64),
          tf.cast(flattened_indices, tf.int64),
      ],
      axis=1,
  )
  (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
      x=paired_indices, axis=[0]
  )
  # NOTE: expected values for the running example above are
  #   flattened_indices = [[1], [1], [2], [0], [2], [0]]
  #   paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
  #   unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
  #   new_index_positions = [0, 0, 1, 2, 3, 4]

  # Next, sum according to the new positions and compute the squared
  # gradient norms. Oddly enough, not sorting these indices will break
  # tensor shape inference logic on TPUs.
  num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
  unique_batch_ids = unique_paired_indices[:, 0]
  summed_gradients = tf.math.unsorted_segment_sum(
      grad_values,
      new_index_positions,
      num_unique_paired_indices,
  )
  sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
  # NOTE: expected values for the running example above are
  #   num_unique_paired_indices = 5
  #   unique_batch_ids = [0, 0, 1, 2, 2]
  #   summed_gradients
  #     = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
  #     = [[0.4], [0.3], [0.1], [0.3], [0.1]]
  #   sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]

  # Use a scatter-add strategy to ensure TPU compatibility.
  result = tf.zeros([nrows])
  return tf.tensor_scatter_nd_add(
      result,
      tf.expand_dims(unique_batch_ids, axis=-1),
      sqr_gradient_sum,
  )
  # NOTE: the expected output for the running example is
  #   [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
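
To make the running example concrete, a small hedged check using a dense (rectangular) variant of the docstring's ids, with the expected values recomputed accordingly:

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils

    # One gradient row per flattened id in [[1, 1, 2], [1, 2, 0]].
    input_ids = tf.constant([[1, 1, 2], [1, 2, 0]])
    grad_values = tf.constant([[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]])

    sqr_norms = registry_function_utils.embedding_sqr_norm_fn(
        grad_values, input_ids
    )
    # Row 0 deduplicates the repeated id 1: (0.2 + 0.2)^2 + 0.3^2 = 0.25.
    # Row 1 has no repeats: 0.1^2 + 0.3^2 + 0.1^2 = 0.11.
    print(sqr_norms.numpy())  # [0.25, 0.11]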