Implement and test a registry function for tfm.nlp.layers.OnDeviceEmbedding.

This CL also moves the `sqr_norm_fn` logic shared between
`tf.keras.layers.Embedding` and `tfm.nlp.layers.OnDeviceEmbedding` into a
new registry function utility file.

PiperOrigin-RevId: 564481407
A. Unique TensorFlower 2023-09-11 13:17:34 -07:00
parent 113b27be43
commit bcc0d4927e
10 changed files with 443 additions and 122 deletions
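For orientation, the new registry function is wired into a `LayerRegistry` exactly as the test file added below does; a condensed sketch (module paths taken from the diff):

import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding

# Register fast-clipping computations for Dense and OnDeviceEmbedding layers,
# mirroring `get_nlp_on_device_embedding_layer_registries()` in the new test.
registry = layer_registry.LayerRegistry()
registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
registry.insert(
    tfm.nlp.layers.OnDeviceEmbedding,
    nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
)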

View file

@@ -37,3 +37,4 @@ tensorflow-datasets~=4.5
tensorflow-estimator~=2.4
tensorflow-probability~=0.20.0
tensorflow~=2.4
tf-models-official~=2.13

View file

@@ -44,6 +44,7 @@ setup(
'tensorflow-estimator~=2.4',
'tensorflow-probability~=0.20.0',
'tensorflow~=2.4',
'tf-models-official~=2.13',
],
packages=find_packages(),
)

View file

@@ -294,15 +294,13 @@ def make_bow_model(
output_dims: List[int],
) -> tf.keras.Model:
"""Creates a simple embedding bow model."""
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dim argument, which is the number of ids
# in each example.
if len(output_dims) != 1:
raise ValueError('Expected `output_dims` to be of size 1.')
output_dim = output_dims[0]
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
emb_layer = layer_generator(input_dims, output_dims)
feature_embs = emb_layer(inputs)
# Embeddings add one extra dimension to its inputs, which combined with the
# batch dimension at dimension 0, equals two additional dimensions compared
@@ -321,18 +319,13 @@ def make_dense_bow_model(
output_dims: List[int],
) -> tf.keras.Model:
"""Creates an embedding bow model with a `Dense` layer."""
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dim argument, which is the number of ids
# in each example.
cardinality = 10
emb_layer = layer_generator(input_dims, output_dims)
if len(output_dims) != 1:
raise ValueError('Expected `output_dims` to be of size 1.')
output_dim = output_dims[0]
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
# Embeddings add one extra dimension to its inputs, which combined with the
# batch dimension at dimension 0, equals two additional dimensions compared
@@ -353,18 +346,14 @@ def make_weighted_bow_model(
) -> tf.keras.Model:
"""Creates a weighted embedding bow model."""
# NOTE: This model only accepts dense input tensors.
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dim argument, which is the number of ids
# in each example.
cardinality = 10
emb_layer = layer_generator(input_dims, output_dims)
if len(output_dims) != 1:
raise ValueError('Expected `output_dims` to be of size 1.')
output_dim = output_dims[0]
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
# Use deterministic weights to avoid seeding issues on TPUs.
feature_shape = input_dims + [output_dim]
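Across all three bow-model builders above, the hard-coded `tf.keras.layers.Embedding` is replaced by a caller-supplied `layer_generator`. A minimal generator, mirroring the one added to the embedding tests below, looks like this (the vocabulary size 10 matches the cardinality that was removed):

def embed_layer_generator(_, output_dims):
  # The first argument (input_dims) is unused; vocabulary size is fixed at 10.
  return tf.keras.layers.Embedding(10, *output_dims)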

View file

@@ -2,6 +2,13 @@ package(
default_visibility = ["//visibility:public"],
)
py_library(
name = "registry_function_utils",
srcs = ["registry_function_utils.py"],
srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
)
py_library(
name = "einsum_utils",
srcs = ["einsum_utils.py"],
@@ -45,7 +52,10 @@ py_library(
name = "embedding",
srcs = ["embedding.py"],
srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
deps = [
":registry_function_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
],
)
py_test(
@@ -63,6 +73,31 @@ py_test(
],
)
py_library(
name = "nlp_on_device_embedding",
srcs = ["nlp_on_device_embedding.py"],
srcs_version = "PY3",
deps = [
":registry_function_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
],
)
py_test(
name = "nlp_on_device_embedding_test",
srcs = ["nlp_on_device_embedding_test.py"],
python_version = "PY3",
shard_count = 6,
srcs_version = "PY3",
deps = [
":dense",
":nlp_on_device_embedding",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)
py_library(
name = "layer_normalization",
srcs = ["layer_normalization.py"],

View file

@@ -16,6 +16,7 @@
from typing import Any, Mapping, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
def embedding_layer_computation(
@@ -71,110 +72,11 @@ def embedding_layer_computation(
tape.watch(base_vars)
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
def sqr_norm_fn(base_vars_grads):
"""Fast square norm function for Keras embedding layers.
Args:
base_vars_grads: A list of batched embedding gradients.
Returns:
A 1D `tf.Tensor` of squared gradient norms.
NOTE: to help understand the code, we document in the function body what
the expected intermediate variables are for the below running example:
input_ids = [[1, 1, 2], [0], [2, 0]]
base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
For ease of reference, we also list these variables below:
row_indices = [[0], [0], [0], [1], [2], [2]]
slice_indices = [[1], [1], [2], [0], [2], [0]]
paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
new_index_positions = [0, 0, 1, 2, 3, 4]
num_unique_paired_indices = 5
unique_batch_ids = [0, 0, 1, 2, 2]
summed_gradients
= [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
= [[0.4], [0.3], [0.1], [0.3], [0.1]]
sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
"""
# We first get a 1D tensor of the row indices.
nrows = tf.shape(input_ids)[0]
if isinstance(input_ids, tf.RaggedTensor):
row_indices = tf.expand_dims(
input_ids.merge_dims(1, -1).value_rowids(), axis=-1
)
elif isinstance(input_ids, tf.Tensor):
ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
repeats = tf.repeat(ncols, nrows)
row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
else:
raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
)
row_indices = tf.cast(row_indices, tf.int64)
if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
nrows = num_microbatches
row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int64
)
# NOTE: expected values for the running example above are
# row_indices = [[0], [0], [0], [1], [2], [2]]
# Now, sum-reduce the `IndexedSlices` that is the result of a
# `tape.gradient()` call. The sum is reduced by the repeated embedding
# indices and batch index. It is adapted from the logic in:
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
if not isinstance(base_vars_grads, tf.IndexedSlices):
raise NotImplementedError(
"Cannot parse embedding gradients of type: %s"
% base_vars_grads.__class__.__name__
)
slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
paired_indices = tf.concat(
[tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
axis=1,
)
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
x=paired_indices, axis=[0]
)
# NOTE: expected values for the running example above are
# slice_indices = [[1], [1], [2], [0], [2], [0]]
# paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# new_index_positions = [0, 0, 1, 2, 3, 4]
# Next, sum according to the new positions and compute the squared
# gradient norms. Oddly enough, not sorting
# these indices will break tensor shape inference logic on TPUs.
num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
unique_batch_ids = unique_paired_indices[:, 0]
summed_gradients = tf.math.unsorted_segment_sum(
def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
return registry_function_utils.embedding_sqr_norm_fn(
base_vars_grads.values,
new_index_positions,
num_unique_paired_indices,
input_ids,
num_microbatches,
)
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
# NOTE: expected values for the running example above are
# num_unique_paired_indices = 5
# unique_batch_ids = [0, 0, 1, 2, 2]
# summed_gradients
# = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
# = [[0.4], [0.3], [0.1], [0.3], [0.1]]
# sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
# Use a scatter-add strategy to ensure TPU compatibility.
result = tf.zeros([nrows])
return tf.tensor_scatter_nd_add(
result,
tf.expand_dims(unique_batch_ids, axis=-1),
sqr_gradient_sum,
)
# NOTE: the expected output for the running example is
# [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
return base_vars, outputs, sqr_norm_fn
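Since the hunk above interleaves removed and added lines, the refactored `sqr_norm_fn` is worth restating on its own: after this change it is a thin wrapper around the shared utility (reconstructed here from the added lines and the `embedding_sqr_norm_fn` signature introduced below):

def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
  return registry_function_utils.embedding_sqr_norm_fn(
      base_vars_grads.values,
      input_ids,
      num_microbatches,
  )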

View file

@@ -115,9 +115,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
# Load shared assets to all devices.
with self.strategy.scope():
def embed_layer_generator(_, output_dims):
return tf.keras.layers.Embedding(10, *output_dims)
model = common_test_utils.get_model_from_generator(
model_generator=get_embedding_model_generators()[model_name],
layer_generator=None,
layer_generator=embed_layer_generator,
input_dims=embed_indices.shape[1:],
output_dims=[output_dim],
is_eager=is_eager,

View file

@@ -0,0 +1,70 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""
from typing import Any, Dict, Optional, Tuple
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
def nlp_on_device_embedding_layer_computation(
layer_instance: tf.keras.layers.Layer,
input_args: Tuple[Any, ...],
input_kwargs: Dict[str, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tfm.nlp.layers.OnDeviceEmbedding`.
Args:
layer_instance: A `tfm.nlp.layers.OnDeviceEmbedding` instance.
input_args: See `dense_layer_computation()` in `dense.py`.
input_kwargs: See `dense_layer_computation()` in `dense.py`.
tape: See `dense_layer_computation()` in `dense.py`.
num_microbatches: See `dense_layer_computation()` in `dense.py`.
Returns:
See `dense_layer_computation()` in `dense.py`.
"""
if input_kwargs:
raise ValueError("Embedding layer calls should not receive kwargs.")
del input_kwargs
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
if hasattr(layer_instance, "_use_one_hot"):
if layer_instance._use_one_hot: # pylint: disable=protected-access
raise NotImplementedError(
"The embedding feature '_use_one_hot' is not supported."
)
# NOTE: Since the implementation of `tfm.nlp.layers.OnDeviceEmbedding` uses
# `.set_shape()`, we can assume that inputs are not ragged.
input_ids = tf.cast(*input_args, tf.int32)
if len(layer_instance.trainable_variables) != 1:
raise ValueError(
"Only layer instances with only one set of trainable variables "
"are permitted."
)
base_vars = layer_instance.trainable_variables[0]
tape.watch(base_vars)
outputs = layer_instance(input_ids)
def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
return registry_function_utils.embedding_sqr_norm_fn(
base_vars_grads.values,
input_ids,
num_microbatches,
)
return base_vars, outputs, sqr_norm_fn
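A minimal sketch of how the pieces returned by this registry function fit together, using a standalone `OnDeviceEmbedding` layer and a toy loss; the variable names and the `vocab_size`/`embedding_width` constructor arguments here are assumptions for illustration, not part of this CL:

layer = tfm.nlp.layers.OnDeviceEmbedding(vocab_size=10, embedding_width=4)
x = tf.constant([[1, 1, 2], [0, 0, 1], [2, 0, 1]])
_ = layer(x)  # Build the layer so its embedding table exists.

with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = (
      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation(
          layer, (x,), {}, tape
      )
  )
  loss = tf.reduce_sum(outputs)

# The gradient w.r.t. the embedding table arrives as `tf.IndexedSlices`, which
# `sqr_norm_fn` reduces to one squared gradient norm per example.
grads = tape.gradient(loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(grads)  # Shape: [3].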

View file

@@ -0,0 +1,159 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding
# ==============================================================================
# Helper functions.
# ==============================================================================
def get_nlp_on_device_embedding_model_generators():
return {
'bow1': common_test_utils.make_bow_model,
'bow2': common_test_utils.make_dense_bow_model,
'weighted_bow1': common_test_utils.make_weighted_bow_model,
}
def get_nlp_on_device_embedding_inputs():
"""Generates input_data."""
return [
# 2D inputs.
[[0, 1]],
[[0, 1], [1, 1], [0, 0]],
# 3D inputs.
[[[0, 1]]],
[[[0, 1]], [[1, 1]], [[0, 0]]],
]
def get_nlp_on_device_embedding_layer_registries():
dbl_registry = layer_registry.LayerRegistry()
dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
dbl_registry.insert(
tfm.nlp.layers.OnDeviceEmbedding,
nlp_on_device_embedding.nlp_on_device_embedding_layer_computation
)
return {
'embed_and_dense': dbl_registry,
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
# supports them for embeddings.
@parameterized.product(
input_data=get_nlp_on_device_embedding_inputs(),
scale_factor=[None, 0.5, 1.0],
model_name=list(
get_nlp_on_device_embedding_model_generators().keys()
),
output_dim=[2],
layer_registry_name=list(
get_nlp_on_device_embedding_layer_registries().keys()
),
num_microbatches=[None, 2],
is_eager=[True, False],
partial=[True, False],
)
def test_gradient_norms_on_various_models(
self,
input_data,
scale_factor,
model_name,
output_dim,
layer_registry_name,
num_microbatches,
is_eager,
partial,
):
# Parse inputs to generate test data.
embed_indices = tf.convert_to_tensor(input_data, dtype_hint=tf.int32)
# The following are invalid test combinations and, hence, are skipped.
batch_size = embed_indices.shape[0]
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if num_microbatches is not None and batch_size % num_microbatches != 0:
return
# Load shared assets to all devices.
with self.strategy.scope():
def embed_layer_generator(_, output_dims):
return tfm.nlp.layers.OnDeviceEmbedding(
10,
*output_dims,
scale_factor=scale_factor
)
model = common_test_utils.get_model_from_generator(
model_generator=(
get_nlp_on_device_embedding_model_generators()[model_name]
),
layer_generator=embed_layer_generator,
input_dims=embed_indices.shape[1:],
output_dims=[output_dim],
is_eager=is_eager,
)
# Define the main testing ops. These may be later compiled to a Graph op.
def test_op(x_batch):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=None,
num_microbatches=num_microbatches,
x_batch=x_batch,
registry=(
get_nlp_on_device_embedding_layer_registries()[
layer_registry_name
]
),
partial=partial,
)
# TPUs can only run `tf.function`-decorated functions.
if using_tpu:
test_op = tf.function(test_op, autograph=False)
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(
test_op, args=(embed_indices,)
)
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
expected_size = num_microbatches or batch_size
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
if __name__ == '__main__':
tf.test.main()

View file

@@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding_test
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
def setUp(self):
super().setUp()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
if __name__ == '__main__':
tf.test.main()

View file

@@ -0,0 +1,131 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Useful utility functions for implementing registry functions."""
from typing import Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def embedding_sqr_norm_fn(
grad_values: tf.Tensor,
input_ids: Union[tf.Tensor, tf.RaggedTensor],
num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Fast square norm function for general embedding layers.
Args:
grad_values: Batched embedding gradient values. These come from the 'values'
field of the `IndexedSlices` object representing the embedding gradient.
input_ids: A `tf.Tensor` of queried embedding ids. The values in this input
are the same as the `indices` field of the `IndexedSlices` object
representing the embedding gradient.
num_microbatches: An optional number of microbatches.
Returns:
A 1D `tf.Tensor` of squared gradient norms.
NOTE: to help understand the code, we document in the function body what
the expected intermediate variables are for the below running example:
input_ids = [[1, 1, 2], [0], [2, 0]]
grad_values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
For ease of reference, we also list these variables below:
row_indices = [[0], [0], [0], [1], [2], [2]]
flattened_indices = [[1], [1], [2], [0], [2], [0]]
paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
new_index_positions = [0, 0, 1, 2, 3, 4]
num_unique_paired_indices = 5
unique_batch_ids = [0, 0, 1, 2, 2]
summed_gradients
= [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
= [[0.4], [0.3], [0.1], [0.3], [0.1]]
sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
"""
# We first get a 1D tensor of the row indices.
nrows = tf.shape(input_ids)[0]
if isinstance(input_ids, tf.RaggedTensor):
row_indices = tf.expand_dims(
input_ids.merge_dims(1, -1).value_rowids(), axis=-1
)
elif isinstance(input_ids, tf.Tensor):
ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
repeats = tf.repeat(ncols, nrows)
row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
else:
raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
)
row_indices = tf.cast(row_indices, tf.int64)
if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
nrows = num_microbatches
row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int64
)
# NOTE: expected values for the running example above are
# row_indices = [[0], [0], [0], [1], [2], [2]]
# Now, sum-reduce the `IndexedSlices` that is the result of a
# `tape.gradient()` call. The sum is reduced by the repeated embedding
# indices and batch index. It is adapted from the logic in:
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
flattened_indices = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1)
paired_indices = tf.concat(
[
tf.cast(row_indices, tf.int64),
tf.cast(flattened_indices, tf.int64)
],
axis=1,
)
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
x=paired_indices, axis=[0]
)
# NOTE: expected values for the running example above are
# flattened_indices = [[1], [1], [2], [0], [2], [0]]
# paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# new_index_positions = [0, 0, 1, 2, 3, 4]
# Next, sum according to the new positions and compute the squared
# gradient norms. Oddly enough, not sorting
# these indices will break tensor shape inference logic on TPUs.
num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
unique_batch_ids = unique_paired_indices[:, 0]
summed_gradients = tf.math.unsorted_segment_sum(
grad_values,
new_index_positions,
num_unique_paired_indices,
)
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
# NOTE: expected values for the running example above are
# num_unique_paired_indices = 5
# unique_batch_ids = [0, 0, 1, 2, 2]
# summed_gradients
# = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
# = [[0.4], [0.3], [0.1], [0.3], [0.1]]
# sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
# Use a scatter-add strategy to ensure TPU compatibility.
result = tf.zeros([nrows])
return tf.tensor_scatter_nd_add(
result,
tf.expand_dims(unique_batch_ids, axis=-1),
sqr_gradient_sum,
)
# NOTE: the expected output for the running example is
# [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
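As a quick sanity check of the logic above, a rectangular (dense) variant of the running example can be fed to `embedding_sqr_norm_fn` directly; the gradient values below are assumed purely for illustration:

input_ids = tf.constant([[1, 2], [0, 0], [2, 0]])
grad_values = tf.constant([[0.2], [0.3], [0.1], [0.1], [0.3], [0.1]])
sqr_norms = embedding_sqr_norm_fn(grad_values, input_ids)
# Example 1 looks up id 0 twice, so its two gradient rows are summed to [0.2]
# before squaring. Expected per-example squared norms:
#   [0.2**2 + 0.3**2, (0.1 + 0.1)**2, 0.3**2 + 0.1**2] = [0.13, 0.04, 0.10]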