Implement and test a registry function for `tfm.nlp.layers.OnDeviceEmbedding`.

This CL also moves the embedding `sqr_norm_fn` logic shared between `tf.keras.layers.Embedding` and `tfm.nlp.layers.OnDeviceEmbedding` into a new registry function utility file.

PiperOrigin-RevId: 564481407
parent 113b27be43
commit bcc0d4927e
10 changed files with 443 additions and 122 deletions
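For orientation before the diff: the new computation plugs into fast gradient clipping through the same `LayerRegistry` pattern as the existing registry functions. A minimal sketch of the intended wiring, adapted from the new test file in this commit:

```python
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding

# Register fast-clipping computations for the layer types a model uses.
registry = layer_registry.LayerRegistry()
registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
registry.insert(
    tfm.nlp.layers.OnDeviceEmbedding,
    nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
)
```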
requirements.txt
@@ -37,3 +37,4 @@ tensorflow-datasets~=4.5
 tensorflow-estimator~=2.4
 tensorflow-probability~=0.20.0
 tensorflow~=2.4
+tf-models-official~=2.13
setup.py
@@ -44,6 +44,7 @@ setup(
         'tensorflow-estimator~=2.4',
         'tensorflow-probability~=0.20.0',
         'tensorflow~=2.4',
+        'tf-models-official~=2.13',
     ],
     packages=find_packages(),
 )
tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py
@@ -294,15 +294,13 @@ def make_bow_model(
     output_dims: List[int],
 ) -> tf.keras.Model:
   """Creates a simple embedding bow model."""
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in each example.
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
+  emb_layer = layer_generator(input_dims, output_dims)
   feature_embs = emb_layer(inputs)
   # Embeddings add one extra dimension to its inputs, which combined with the
   # batch dimension at dimension 0, equals two additional dimensions compared
@@ -321,18 +319,13 @@ def make_dense_bow_model(
     output_dims: List[int],
 ) -> tf.keras.Model:
   """Creates an embedding bow model with a `Dense` layer."""
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in each example.
-  cardinality = 10
+  emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(
-      input_dim=cardinality, output_dim=output_dim
-  )
   feature_embs = emb_layer(inputs)
   # Embeddings add one extra dimension to its inputs, which combined with the
   # batch dimension at dimension 0, equals two additional dimensions compared
@@ -353,18 +346,14 @@ def make_weighted_bow_model(
 ) -> tf.keras.Model:
   """Creates a weighted embedding bow model."""
   # NOTE: This model only accepts dense input tensors.
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in each example.
-  cardinality = 10
+  emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
   output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(
-      input_dim=cardinality, output_dim=output_dim
-  )
   feature_embs = emb_layer(inputs)
   # Use deterministic weights to avoid seeding issues on TPUs.
   feature_shape = input_dims + [output_dim]
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
@@ -2,6 +2,13 @@ package(
     default_visibility = ["//visibility:public"],
 )
 
+py_library(
+    name = "registry_function_utils",
+    srcs = ["registry_function_utils.py"],
+    srcs_version = "PY3",
+    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+)
+
 py_library(
     name = "einsum_utils",
     srcs = ["einsum_utils.py"],
@@ -45,7 +52,10 @@ py_library(
     name = "embedding",
     srcs = ["embedding.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+    deps = [
+        ":registry_function_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
 )
 
 py_test(
@@ -63,6 +73,31 @@ py_test(
     ],
 )
 
+py_library(
+    name = "nlp_on_device_embedding",
+    srcs = ["nlp_on_device_embedding.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":registry_function_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
+)
+
+py_test(
+    name = "nlp_on_device_embedding_test",
+    srcs = ["nlp_on_device_embedding_test.py"],
+    python_version = "PY3",
+    shard_count = 6,
+    srcs_version = "PY3",
+    deps = [
+        ":dense",
+        ":nlp_on_device_embedding",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+    ],
+)
+
 py_library(
     name = "layer_normalization",
     srcs = ["layer_normalization.py"],
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py
@@ -16,6 +16,7 @@
 from typing import Any, Mapping, Tuple, Union
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
 
 
 def embedding_layer_computation(
@@ -71,110 +72,11 @@ def embedding_layer_computation(
   tape.watch(base_vars)
   outputs = tf.nn.embedding_lookup(base_vars, input_ids)
 
-  def sqr_norm_fn(base_vars_grads):
-    """Fast square norm function for Keras embedding layers.
-
-    Args:
-      base_vars_grads: A list of batched embedding gradients.
-
-    Returns:
-      A 1D `tf.Tensor` of squared gradient norms.
-
-    NOTE: to help understand the code, we document in the function body what
-    the expected intermediate variables are for the below running example:
-
-      input_ids = [[1, 1, 2], [0], [2, 0]]
-      base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
-      base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
-
-    For ease of reference, we also list these variables below:
-
-      row_indices = [[0], [0], [0], [1], [2], [2]]
-      slice_indices = [[1], [1], [2], [0], [2], [0]]
-      paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-      unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-      new_index_positions = [0, 0, 1, 2, 3, 4]
-      num_unique_paired_indices = 5
-      unique_batch_ids = [0, 0, 1, 2, 2]
-      summed_gradients
-        = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
-        = [[0.4], [0.3], [0.1], [0.3], [0.1]]
-      sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
-    """
-    # We first get a 1D tensor of the row indices.
-    nrows = tf.shape(input_ids)[0]
-    if isinstance(input_ids, tf.RaggedTensor):
-      row_indices = tf.expand_dims(
-          input_ids.merge_dims(1, -1).value_rowids(), axis=-1
-      )
-    elif isinstance(input_ids, tf.Tensor):
-      ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
-      repeats = tf.repeat(ncols, nrows)
-      row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
-    else:
-      raise NotImplementedError(
-          "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
-      )
-    row_indices = tf.cast(row_indices, tf.int64)
-    if num_microbatches is not None:
-      microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
-      nrows = num_microbatches
-      row_indices = tf.cast(
-          tf.math.floordiv(row_indices, microbatch_size), tf.int64
-      )
-    # NOTE: expected values for the running example above are
-    #   row_indices = [[0], [0], [0], [1], [2], [2]]
-
-    # Now, sum-reduce the `IndexedSlices` that is the result of a
-    # `tape.gradient()` call. The sum is reduced by the repeated embedding
-    # indices and batch index. It is adapted from the logic in:
-    #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
-    if not isinstance(base_vars_grads, tf.IndexedSlices):
-      raise NotImplementedError(
-          "Cannot parse embedding gradients of type: %s"
-          % base_vars_grads.__class__.__name__
-      )
-    slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
-    paired_indices = tf.concat(
-        [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
-        axis=1,
-    )
-    (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
-        x=paired_indices, axis=[0]
-    )
-    # NOTE: expected values for the running example above are
-    #   slice_indices = [[1], [1], [2], [0], [2], [0]]
-    #   paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-    #   unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-    #   new_index_positions = [0, 0, 1, 2, 3, 4]
-
-    # Next, sum according to the new positions and compute the squared
-    # gradient norms. Oddly enough, not sorting
-    # these indices will break tensor shape inference logic on TPUs.
-    num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
-    unique_batch_ids = unique_paired_indices[:, 0]
-    summed_gradients = tf.math.unsorted_segment_sum(
-        base_vars_grads.values,
-        new_index_positions,
-        num_unique_paired_indices,
-    )
-    sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
-    # NOTE: expected values for the running example above are
-    #   num_unique_paired_indices = 5
-    #   unique_batch_ids = [0, 0, 1, 2, 2]
-    #   summed_gradients
-    #     = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
-    #     = [[0.4], [0.3], [0.1], [0.3], [0.1]]
-    #   sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
-
-    # Use a scatter-add strategy to ensure TPU compatibility.
-    result = tf.zeros([nrows])
-    return tf.tensor_scatter_nd_add(
-        result,
-        tf.expand_dims(unique_batch_ids, axis=-1),
-        sqr_gradient_sum,
-    )
-    # NOTE: the expected output for the running example is
-    #   [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
+  def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
+    return registry_function_utils.embedding_sqr_norm_fn(
+        base_vars_grads.values,
+        input_ids,
+        num_microbatches,
+    )
 
   return base_vars, outputs, sqr_norm_fn
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py
@@ -115,9 +115,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
 
     # Load shared assets to all devices.
     with self.strategy.scope():
+
+      def embed_layer_generator(_, output_dims):
+        return tf.keras.layers.Embedding(10, *output_dims)
+
       model = common_test_utils.get_model_from_generator(
           model_generator=get_embedding_model_generators()[model_name],
-          layer_generator=None,
+          layer_generator=embed_layer_generator,
          input_dims=embed_indices.shape[1:],
           output_dims=[output_dim],
           is_eager=is_eager,
nlp_on_device_embedding.py (new)
@@ -0,0 +1,70 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""
+
+from typing import Any, Dict, Optional, Tuple
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
+
+
+def nlp_on_device_embedding_layer_computation(
+    layer_instance: tf.keras.layers.Layer,
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[str, Any],
+    tape: tf.GradientTape,
+    num_microbatches: Optional[tf.Tensor] = None,
+) -> type_aliases.RegistryFunctionOutput:
+  """Registry function for `tfm.nlp.layers.OnDeviceEmbedding`.
+
+  Args:
+    layer_instance: A `tfm.nlp.layers.OnDeviceEmbedding` instance.
+    input_args: See `dense_layer_computation()` in `dense.py`.
+    input_kwargs: See `dense_layer_computation()` in `dense.py`.
+    tape: See `dense_layer_computation()` in `dense.py`.
+    num_microbatches: See `dense_layer_computation()` in `dense.py`.
+
+  Returns:
+    See `dense_layer_computation()` in `dense.py`.
+  """
+  if input_kwargs:
+    raise ValueError("Embedding layer calls should not receive kwargs.")
+  del input_kwargs
+  if len(input_args) != 1:
+    raise ValueError("Only layer inputs of length 1 are permitted.")
+  if hasattr(layer_instance, "_use_one_hot"):
+    if layer_instance._use_one_hot:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "The embedding feature '_use_one_hot' is not supported."
+      )
+  # NOTE: Since the implementation of `tfm.nlp.layers.OnDeviceEmbedding` uses
+  # `.set_shape()`, we can assume that inputs are not ragged.
+  input_ids = tf.cast(*input_args, tf.int32)
+  if len(layer_instance.trainable_variables) != 1:
+    raise ValueError(
+        "Only layer instances with only one set of trainable variables "
+        "are permitted."
+    )
+  base_vars = layer_instance.trainable_variables[0]
+  tape.watch(base_vars)
+  outputs = layer_instance(input_ids)
+
+  def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
+    return registry_function_utils.embedding_sqr_norm_fn(
+        base_vars_grads.values,
+        input_ids,
+        num_microbatches,
+    )
+
+  return base_vars, outputs, sqr_norm_fn
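To make the contract above concrete, here is a hedged sketch of how the clipping machinery consumes this registry function: it returns the watched embedding variable, the layer outputs, and a `sqr_norm_fn` that maps the `IndexedSlices` gradient of a summed loss into per-example squared gradient norms. The toy loss and tensor values below are illustrative stand-ins, not part of the commit:

```python
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding

layer = tfm.nlp.layers.OnDeviceEmbedding(vocab_size=10, embedding_width=2)
x = tf.constant([[1, 1, 2], [2, 0, 0]])  # batch of token ids
_ = layer(x)  # build the layer so its embedding table exists

with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = (
      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation(
          layer, (x,), {}, tape
      )
  )
  loss = tf.reduce_sum(outputs)  # stand-in for the model's summed loss

grads = tape.gradient(loss, base_vars)  # a `tf.IndexedSlices`
per_example_sqr_norms = sqr_norm_fn(grads)  # shape [batch_size]
```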
nlp_on_device_embedding_test.py (new)
@@ -0,0 +1,159 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+import tensorflow as tf
+import tensorflow_models as tfm
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding
+
+
+# ==============================================================================
+# Helper functions.
+# ==============================================================================
+def get_nlp_on_device_embedding_model_generators():
+  return {
+      'bow1': common_test_utils.make_bow_model,
+      'bow2': common_test_utils.make_dense_bow_model,
+      'weighted_bow1': common_test_utils.make_weighted_bow_model,
+  }
+
+
+def get_nlp_on_device_embedding_inputs():
+  """Generates input_data."""
+  return [
+      # 2D inputs.
+      [[0, 1]],
+      [[0, 1], [1, 1], [0, 0]],
+      # 3D inputs.
+      [[[0, 1]]],
+      [[[0, 1]], [[1, 1]], [[0, 0]]],
+  ]
+
+
+def get_nlp_on_device_embedding_layer_registries():
+  dbl_registry = layer_registry.LayerRegistry()
+  dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
+  dbl_registry.insert(
+      tfm.nlp.layers.OnDeviceEmbedding,
+      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation
+  )
+  return {
+      'embed_and_dense': dbl_registry,
+  }
+
+
+# ==============================================================================
+# Main tests.
+# ==============================================================================
+class GradNormTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = tf.distribute.get_strategy()
+
+  # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
+  # supports them for embeddings.
+  @parameterized.product(
+      input_data=get_nlp_on_device_embedding_inputs(),
+      scale_factor=[None, 0.5, 1.0],
+      model_name=list(
+          get_nlp_on_device_embedding_model_generators().keys()
+      ),
+      output_dim=[2],
+      layer_registry_name=list(
+          get_nlp_on_device_embedding_layer_registries().keys()
+      ),
+      num_microbatches=[None, 2],
+      is_eager=[True, False],
+      partial=[True, False],
+  )
+  def test_gradient_norms_on_various_models(
+      self,
+      input_data,
+      scale_factor,
+      model_name,
+      output_dim,
+      layer_registry_name,
+      num_microbatches,
+      is_eager,
+      partial,
+  ):
+    # Parse inputs to generate test data.
+    embed_indices = tf.convert_to_tensor(input_data, dtype_hint=tf.int32)
+
+    # The following are invalid test combinations and, hence, are skipped.
+    batch_size = embed_indices.shape[0]
+    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
+    if num_microbatches is not None and batch_size % num_microbatches != 0:
+      return
+
+    # Load shared assets to all devices.
+    with self.strategy.scope():
+
+      def embed_layer_generator(_, output_dims):
+        return tfm.nlp.layers.OnDeviceEmbedding(
+            10,
+            *output_dims,
+            scale_factor=scale_factor
+        )
+
+      model = common_test_utils.get_model_from_generator(
+          model_generator=(
+              get_nlp_on_device_embedding_model_generators()[model_name]
+          ),
+          layer_generator=embed_layer_generator,
+          input_dims=embed_indices.shape[1:],
+          output_dims=[output_dim],
+          is_eager=is_eager,
+      )
+
+    # Define the main testing ops. These may be later compiled to a Graph op.
+    def test_op(x_batch):
+      return common_test_utils.get_computed_and_true_norms_from_model(
+          model=model,
+          per_example_loss_fn=None,
+          num_microbatches=num_microbatches,
+          x_batch=x_batch,
+          registry=(
+              get_nlp_on_device_embedding_layer_registries()[
+                  layer_registry_name
+              ]
+          ),
+          partial=partial,
+      )
+
+    # TPUs can only run `tf.function`-decorated functions.
+    if using_tpu:
+      test_op = tf.function(test_op, autograph=False)
+
+    # Set up the device ops and run the test.
+    computed_norms, true_norms = self.strategy.run(
+        test_op, args=(embed_indices,)
+    )
+    # TPUs return replica contexts, which must be unwrapped.
+    if using_tpu:
+      common_test_utils.assert_replica_values_are_close(self, computed_norms)
+      common_test_utils.assert_replica_values_are_close(self, true_norms)
+      computed_norms = computed_norms.values[0]
+      true_norms = true_norms.values[0]
+    expected_size = num_microbatches or batch_size
+    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
+    self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
+
+
+if __name__ == '__main__':
+  tf.test.main()
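Assuming the Bazel targets added in the BUILD file above, the new test would be run along the lines of:

```
bazel test //tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions:nlp_on_device_embedding_test
```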
nlp_on_device_embedding_tpu_test.py (new)
@@ -0,0 +1,29 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding_test
+
+
+class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = common_test_utils.create_tpu_strategy()
+    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+
+
+if __name__ == '__main__':
+  tf.test.main()
registry_function_utils.py (new)
@@ -0,0 +1,131 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Useful utility functions for implementing registry functions."""
+
+from typing import Union
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+
+
+def embedding_sqr_norm_fn(
+    grad_values: tf.Tensor,
+    input_ids: Union[tf.Tensor, tf.RaggedTensor],
+    num_microbatches: Union[tf.Tensor, None] = None,
+) -> type_aliases.RegistryFunctionOutput:
+  """Fast square norm function for general embedding layers.
+
+  Args:
+    grad_values: Batched embedding gradient values. These come from the 'values'
+      field of the `IndexedSlices` object representing the embedding gradient.
+    input_ids: A `tf.Tensor` of queried embedding ids. The values in this input
+      are the same as the `indices` field of the `IndexedSlices` object
+      representing the embedding gradient.
+    num_microbatches: An optional number of microbatches.
+
+  Returns:
+    A 1D `tf.Tensor` of squared gradient norms.
+
+  NOTE: to help understand the code, we document in the function body what
+  the expected intermediate variables are for the below running example:
+
+    input_ids = [[1, 1, 2], [0], [2, 0]]
+    grad_values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
+
+  For ease of reference, we also list these variables below:
+
+    row_indices = [[0], [0], [0], [1], [2], [2]]
+    flattened_indices = [[1], [1], [2], [0], [2], [0]]
+    paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+    unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+    new_index_positions = [0, 0, 1, 2, 3, 4]
+    num_unique_paired_indices = 5
+    unique_batch_ids = [0, 0, 1, 2, 2]
+    summed_gradients
+      = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
+      = [[0.4], [0.3], [0.1], [0.3], [0.1]]
+    sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
+  """
+  # We first get a 1D tensor of the row indices.
+  nrows = tf.shape(input_ids)[0]
+  if isinstance(input_ids, tf.RaggedTensor):
+    row_indices = tf.expand_dims(
+        input_ids.merge_dims(1, -1).value_rowids(), axis=-1
+    )
+  elif isinstance(input_ids, tf.Tensor):
+    ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
+    repeats = tf.repeat(ncols, nrows)
+    row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
+  else:
+    raise NotImplementedError(
+        "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
+    )
+  row_indices = tf.cast(row_indices, tf.int64)
+  if num_microbatches is not None:
+    microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
+    nrows = num_microbatches
+    row_indices = tf.cast(
+        tf.math.floordiv(row_indices, microbatch_size), tf.int64
+    )
+  # NOTE: expected values for the running example above are
+  #   row_indices = [[0], [0], [0], [1], [2], [2]]
+
+  # Now, sum-reduce the `IndexedSlices` that is the result of a
+  # `tape.gradient()` call. The sum is reduced by the repeated embedding
+  # indices and batch index. It is adapted from the logic in:
+  #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
+  flattened_indices = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1)
+  paired_indices = tf.concat(
+      [
+          tf.cast(row_indices, tf.int64),
+          tf.cast(flattened_indices, tf.int64)
+      ],
+      axis=1,
+  )
+  (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
+      x=paired_indices, axis=[0]
+  )
+  # NOTE: expected values for the running example above are
+  #   flattened_indices = [[1], [1], [2], [0], [2], [0]]
+  #   paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+  #   unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+  #   new_index_positions = [0, 0, 1, 2, 3, 4]
+
+  # Next, sum according to the new positions and compute the squared
+  # gradient norms. Oddly enough, not sorting
+  # these indices will break tensor shape inference logic on TPUs.
+  num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
+  unique_batch_ids = unique_paired_indices[:, 0]
+  summed_gradients = tf.math.unsorted_segment_sum(
+      grad_values,
+      new_index_positions,
+      num_unique_paired_indices,
+  )
+  sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
+  # NOTE: expected values for the running example above are
+  #   num_unique_paired_indices = 5
+  #   unique_batch_ids = [0, 0, 1, 2, 2]
+  #   summed_gradients
+  #     = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
+  #     = [[0.4], [0.3], [0.1], [0.3], [0.1]]
+  #   sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
+
+  # Use a scatter-add strategy to ensure TPU compatibility.
+  result = tf.zeros([nrows])
+  return tf.tensor_scatter_nd_add(
+      result,
+      tf.expand_dims(unique_batch_ids, axis=-1),
+      sqr_gradient_sum,
+  )
+  # NOTE: the expected output for the running example is
+  #   [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
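As a quick sanity check of the utility above, a minimal sketch that calls it directly with dense ids (a rectangular variant of the docstring's running example; the gradient values are hand-picked stand-ins rather than the output of a real tape):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils

# Two examples with two queried ids each.
input_ids = tf.constant([[1, 1], [2, 0]])
# Per-position gradient values, i.e. the `values` field of the embedding
# gradient's `IndexedSlices`, one row per queried id.
grad_values = tf.constant([[0.2], [0.2], [0.3], [0.1]])

sqr_norms = registry_function_utils.embedding_sqr_norm_fn(grad_values, input_ids)
# Example 0 first sums the gradients of the duplicated id 1: (0.2 + 0.2)^2 = 0.16.
# Example 1: 0.3^2 + 0.1^2 = 0.10.
print(sqr_norms.numpy())  # -> approximately [0.16, 0.10]
```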