Modify fast clipping logic to support computation on TPUs.
PiperOrigin-RevId: 550673798
This commit is contained in:
parent
cb6659d11b
commit
c1c97f1c1c
7 changed files with 417 additions and 124 deletions
|
@ -15,7 +15,9 @@
|
|||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v2 as tf_compat
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||
|
@ -24,6 +26,23 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
|||
# ==============================================================================
|
||||
# Helper functions
|
||||
# ==============================================================================
|
||||
def create_tpu_strategy():
|
||||
"""Initializes a TPU environment."""
|
||||
# Done to avoid transferring data between CPUs and TPUs.
|
||||
tf_compat.config.set_soft_device_placement(False)
|
||||
resolver = tf_compat.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
tf_compat.config.experimental_connect_to_cluster(resolver)
|
||||
tf_compat.tpu.experimental.initialize_tpu_system(resolver)
|
||||
return tf_compat.distribute.TPUStrategy(resolver)
|
||||
|
||||
|
||||
def assert_replica_values_are_close(test_case_obj, replica_context):
|
||||
"""Checks if all replica context tensors are near each other."""
|
||||
base_tensor = replica_context.values[0]
|
||||
for t in replica_context.values[1:]:
|
||||
test_case_obj.assertAllClose(base_tensor, t)
|
||||
|
||||
|
||||
def get_nd_test_batches(n: int):
|
||||
"""Returns a list of input batches of dimension n."""
|
||||
# The first two batches have a single element, the last batch has 2 elements.
|
||||
|
@ -79,13 +98,76 @@ def compute_true_gradient_norms(
|
|||
trainable_vars = trainable_vars or input_model.trainable_variables
|
||||
for var in trainable_vars:
|
||||
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
|
||||
reduction_axes = tf.range(1, len(jacobian.shape))
|
||||
reduction_axes = tf.range(1, tf.rank(jacobian))
|
||||
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
|
||||
sqr_norms.append(sqr_norm)
|
||||
sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
|
||||
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
|
||||
|
||||
|
||||
def get_model_from_generator(
|
||||
model_generator: type_aliases.ModelGenerator,
|
||||
layer_generator: type_aliases.LayerGenerator,
|
||||
input_dims: Union[int, List[int]],
|
||||
output_dim: int,
|
||||
is_eager: bool,
|
||||
) -> tf.keras.Model:
|
||||
"""Creates a simple model from input specifications."""
|
||||
model = model_generator(layer_generator, input_dims, output_dim)
|
||||
model.compile(
|
||||
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
|
||||
loss=tf.keras.losses.MeanSquaredError(
|
||||
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
|
||||
),
|
||||
run_eagerly=is_eager,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def get_computed_and_true_norms_from_model(
|
||||
model: tf.keras.Model,
|
||||
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||
num_microbatches: Optional[int],
|
||||
x_batch: tf.Tensor,
|
||||
weight_batch: Optional[tf.Tensor] = None,
|
||||
rng_seed: int = 777,
|
||||
registry: layer_registry.LayerRegistry = None,
|
||||
partial: bool = False,
|
||||
):
|
||||
"""Generates relevant norms from an input model and other specs."""
|
||||
trainable_vars = None
|
||||
if partial:
|
||||
# Gets the first layer with variables.
|
||||
for l in model.layers:
|
||||
trainable_vars = l.trainable_variables
|
||||
if trainable_vars:
|
||||
break
|
||||
y_pred = model(x_batch)
|
||||
y_batch = tf.ones_like(y_pred)
|
||||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
computed_norms = clip_grads.compute_gradient_norms(
|
||||
input_model=model,
|
||||
x_batch=x_batch,
|
||||
y_batch=y_batch,
|
||||
weight_batch=weight_batch,
|
||||
layer_registry=registry,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
trainable_vars=trainable_vars,
|
||||
)
|
||||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
true_norms = compute_true_gradient_norms(
|
||||
input_model=model,
|
||||
x_batch=x_batch,
|
||||
y_batch=y_batch,
|
||||
weight_batch=weight_batch,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
trainable_vars=trainable_vars,
|
||||
)
|
||||
return computed_norms, true_norms
|
||||
|
||||
|
||||
def get_computed_and_true_norms(
|
||||
model_generator: type_aliases.ModelGenerator,
|
||||
layer_generator: type_aliases.LayerGenerator,
|
||||
|
@ -133,45 +215,23 @@ def get_computed_and_true_norms(
|
|||
model and layer generators. The second element contains the true clipped
|
||||
gradient norms under the aforementioned setting.
|
||||
"""
|
||||
model = model_generator(layer_generator, input_dims, output_dim)
|
||||
model.compile(
|
||||
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
|
||||
loss=tf.keras.losses.MeanSquaredError(
|
||||
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
|
||||
),
|
||||
run_eagerly=is_eager,
|
||||
model = get_model_from_generator(
|
||||
model_generator=model_generator,
|
||||
layer_generator=layer_generator,
|
||||
input_dims=input_dims,
|
||||
output_dim=output_dim,
|
||||
is_eager=is_eager,
|
||||
)
|
||||
trainable_vars = None
|
||||
if partial:
|
||||
# Gets the first layer with variables.
|
||||
for l in model.layers:
|
||||
trainable_vars = l.trainable_variables
|
||||
if trainable_vars:
|
||||
break
|
||||
y_pred = model(x_batch)
|
||||
y_batch = tf.ones_like(y_pred)
|
||||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
computed_norms = clip_grads.compute_gradient_norms(
|
||||
input_model=model,
|
||||
x_batch=x_batch,
|
||||
y_batch=y_batch,
|
||||
weight_batch=weight_batch,
|
||||
layer_registry=registry,
|
||||
return get_computed_and_true_norms_from_model(
|
||||
model=model,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
trainable_vars=trainable_vars,
|
||||
x_batch=x_batch,
|
||||
weight_batch=weight_batch,
|
||||
rng_seed=rng_seed,
|
||||
registry=registry,
|
||||
partial=partial,
|
||||
)
|
||||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
true_norms = compute_true_gradient_norms(
|
||||
model,
|
||||
x_batch,
|
||||
y_batch,
|
||||
weight_batch,
|
||||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
trainable_vars=trainable_vars,
|
||||
)
|
||||
return (computed_norms, true_norms)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -234,7 +294,11 @@ def make_bow_model(layer_generator, input_dims, output_dim):
|
|||
# in eache example.
|
||||
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
|
||||
feature_embs = emb_layer(inputs)
|
||||
reduction_axes = tf.range(1, len(feature_embs.shape))
|
||||
# Embeddings add one extra dimension to its inputs, which combined with the
|
||||
# batch dimension at dimension 0, equals two additional dimensions compared
|
||||
# to the number of input dimensions. Here, we want to reduce over the output
|
||||
# space, but exclude the batch dimension.
|
||||
reduction_axes = range(1, len(input_dims) + 2)
|
||||
example_embs = tf.expand_dims(
|
||||
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
|
||||
)
|
||||
|
@ -253,7 +317,11 @@ def make_dense_bow_model(layer_generator, input_dims, output_dim):
|
|||
input_dim=cardinality, output_dim=output_dim
|
||||
)
|
||||
feature_embs = emb_layer(inputs)
|
||||
reduction_axes = tf.range(1, len(feature_embs.shape))
|
||||
# Embeddings add one extra dimension to its inputs, which combined with the
|
||||
# batch dimension at dimension 0, equals two additional dimensions compared
|
||||
# to the number of input dimensions. Here, we want to reduce over the output
|
||||
# space, but exclude the batch dimension.
|
||||
reduction_axes = range(1, len(input_dims) + 2)
|
||||
example_embs = tf.expand_dims(
|
||||
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
|
||||
)
|
||||
|
@ -274,9 +342,21 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
|
|||
input_dim=cardinality, output_dim=output_dim
|
||||
)
|
||||
feature_embs = emb_layer(inputs)
|
||||
feature_weights = tf.random.uniform(tf.shape(feature_embs))
|
||||
# Use deterministic weights to avoid seeding issues on TPUs.
|
||||
feature_shape = input_dims + [output_dim]
|
||||
feature_weights = tf.expand_dims(
|
||||
tf.reshape(
|
||||
tf.range(np.product(feature_shape), dtype=tf.float32),
|
||||
feature_shape,
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
weighted_embs = feature_embs * feature_weights
|
||||
reduction_axes = tf.range(1, len(weighted_embs.shape))
|
||||
# Embeddings add one extra dimension to its inputs, which combined with the
|
||||
# batch dimension at dimension 0, equals two additional dimensions compared
|
||||
# to the number of input dimensions. Here, we want to reduce over the output
|
||||
# space, but exclude the batch dimension.
|
||||
reduction_axes = range(1, len(input_dims) + 2)
|
||||
example_embs = tf.expand_dims(
|
||||
tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
|
||||
)
|
||||
|
|
|
@ -19,9 +19,10 @@ py_test(
|
|||
size = "large",
|
||||
srcs = ["dense_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 8,
|
||||
shard_count = 12,
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dense",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
|
||||
|
@ -37,12 +38,13 @@ py_library(
|
|||
|
||||
py_test(
|
||||
name = "embedding_test",
|
||||
size = "large",
|
||||
srcs = ["embedding_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 8,
|
||||
shard_count = 12,
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dense",
|
||||
":embedding",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
|
||||
|
|
|
@ -16,6 +16,7 @@ from absl.testing import parameterized
|
|||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -40,16 +41,30 @@ def get_dense_model_generators():
|
|||
}
|
||||
|
||||
|
||||
def get_dense_layer_registries():
|
||||
dense_registry = layer_registry.LayerRegistry()
|
||||
dense_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
|
||||
return {
|
||||
'dense_only': dense_registry,
|
||||
'default': layer_registry.make_default_layer_registry(),
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main tests.
|
||||
# ==============================================================================
|
||||
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.strategy = tf.distribute.get_strategy()
|
||||
|
||||
@parameterized.product(
|
||||
model_name=list(get_dense_model_generators().keys()),
|
||||
layer_name=list(get_dense_layer_generators().keys()),
|
||||
input_dim=[4],
|
||||
output_dim=[2],
|
||||
layer_registry_name=list(get_dense_layer_registries().keys()),
|
||||
per_example_loss_fn=[None, common_test_utils.test_loss_fn],
|
||||
num_microbatches=[None, 1, 2],
|
||||
is_eager=[True, False],
|
||||
|
@ -62,38 +77,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
|||
layer_name,
|
||||
input_dim,
|
||||
output_dim,
|
||||
layer_registry_name,
|
||||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
partial,
|
||||
weighted,
|
||||
):
|
||||
model_generator = get_dense_model_generators()[model_name]
|
||||
layer_generator = get_dense_layer_generators()[layer_name]
|
||||
# Parse inputs to generate test data.
|
||||
x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim)
|
||||
default_registry = layer_registry.make_default_layer_registry()
|
||||
|
||||
# Load shared assets to all devices.
|
||||
with self.strategy.scope():
|
||||
model = common_test_utils.get_model_from_generator(
|
||||
model_generator=get_dense_model_generators()[model_name],
|
||||
layer_generator=get_dense_layer_generators()[layer_name],
|
||||
input_dims=input_dim,
|
||||
output_dim=output_dim,
|
||||
is_eager=is_eager,
|
||||
)
|
||||
|
||||
# Define the main testing ops. These may be later compiled to a Graph op.
|
||||
def test_op(x_batch, weight_batch):
|
||||
return common_test_utils.get_computed_and_true_norms_from_model(
|
||||
model=model,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
|
||||
weight_batch=weight_batch if weighted else None,
|
||||
registry=get_dense_layer_registries()[layer_registry_name],
|
||||
partial=partial,
|
||||
)
|
||||
|
||||
# TPUs can only run `tf.function`-decorated functions.
|
||||
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
|
||||
if using_tpu:
|
||||
test_op = tf.function(test_op, jit_compile=True, autograph=False)
|
||||
|
||||
# TPUs use lower precision than CPUs, so we relax our criterion.
|
||||
# E.g., one of the TPU runs generated the following results:
|
||||
#
|
||||
# computed_norm = 22.530651
|
||||
# true_norm = 22.570976
|
||||
# abs_diff = 0.04032516
|
||||
# rel_diff = 0.00178659
|
||||
#
|
||||
# which is a reasonable level of error for computing gradient norms.
|
||||
# Other trials also give an absolute (resp. relative) error of around
|
||||
# 0.05 (resp. 0.0015).
|
||||
rtol = 1e-2 if using_tpu else 1e-3
|
||||
atol = 1e-1 if using_tpu else 1e-2
|
||||
|
||||
for x_batch, weight_batch in zip(x_batches, weight_batches):
|
||||
batch_size = x_batch.shape[0]
|
||||
if num_microbatches is not None and batch_size % num_microbatches != 0:
|
||||
continue
|
||||
computed_norms, true_norms = (
|
||||
common_test_utils.get_computed_and_true_norms(
|
||||
model_generator,
|
||||
layer_generator,
|
||||
input_dim,
|
||||
output_dim,
|
||||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
|
||||
weight_batch=weight_batch if weighted else None,
|
||||
registry=default_registry,
|
||||
partial=partial,
|
||||
)
|
||||
# Set up the device ops and run the test.
|
||||
computed_norms, true_norms = self.strategy.run(
|
||||
test_op, args=(x_batch, weight_batch)
|
||||
)
|
||||
# TPUs return replica contexts, which must be unwrapped.
|
||||
if using_tpu:
|
||||
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
||||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||
computed_norms = computed_norms.values[0]
|
||||
true_norms = true_norms.values[0]
|
||||
expected_size = num_microbatches or batch_size
|
||||
self.assertEqual(computed_norms.shape[0], expected_size)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2023, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense_test
|
||||
|
||||
|
||||
class GradNormTpuTest(dense_test.GradNormTest):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.strategy = ctu.create_tpu_strategy()
|
||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -62,12 +62,46 @@ def embedding_layer_computation(
|
|||
"'_use_one_hot_matmul' is not supported."
|
||||
)
|
||||
input_ids = tf.cast(*input_args, tf.int32)
|
||||
if len(layer_instance.trainable_variables) != 1:
|
||||
raise ValueError(
|
||||
"Only layer instances with only one set of trainable variables"
|
||||
"are permitted."
|
||||
)
|
||||
base_vars = layer_instance.trainable_variables[0]
|
||||
tape.watch(base_vars)
|
||||
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
|
||||
|
||||
def sqr_norm_fn(base_vars_grads):
|
||||
# Get a 1D tensor of the row indices.
|
||||
"""Fast square norm function for Keras embedding layers.
|
||||
|
||||
Args:
|
||||
base_vars_grads: A list of batched embedding gradients.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` of squared gradient norms.
|
||||
|
||||
NOTE: to help understand the code, we document in the function body what
|
||||
the expected intermediate variables are for the below running example:
|
||||
|
||||
input_ids = [[1, 1, 2], [0], [2, 0]]
|
||||
base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
|
||||
base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
|
||||
|
||||
For ease of reference, we also list these variables below:
|
||||
|
||||
row_indices = [[0], [0], [0], [1], [2], [2]]
|
||||
slice_indices = [[1], [1], [2], [0], [2], [0]]
|
||||
paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
|
||||
unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
|
||||
new_index_positions = [0, 0, 1, 2, 3, 4]
|
||||
num_unique_paired_indices = 5
|
||||
unique_batch_ids = [0, 0, 1, 2, 2]
|
||||
summed_gradients
|
||||
= [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
|
||||
= [[0.4], [0.3], [0.1], [0.3], [0.1]]
|
||||
sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
|
||||
"""
|
||||
# We first get a 1D tensor of the row indices.
|
||||
nrows = tf.shape(input_ids)[0]
|
||||
if isinstance(input_ids, tf.RaggedTensor):
|
||||
row_indices = tf.expand_dims(
|
||||
|
@ -81,16 +115,19 @@ def embedding_layer_computation(
|
|||
raise NotImplementedError(
|
||||
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
|
||||
)
|
||||
row_indices = tf.cast(row_indices, tf.int32)
|
||||
row_indices = tf.cast(row_indices, tf.int64)
|
||||
if num_microbatches is not None:
|
||||
microbatch_size = tf.cast(nrows / num_microbatches, tf.int32)
|
||||
microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
|
||||
nrows = num_microbatches
|
||||
row_indices = tf.cast(
|
||||
tf.math.floordiv(row_indices, microbatch_size), tf.int32
|
||||
tf.math.floordiv(row_indices, microbatch_size), tf.int64
|
||||
)
|
||||
# Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()`
|
||||
# call. The sum is reduced by the repeated embedding indices and batch
|
||||
# index. It is adapted from the logic in:
|
||||
# NOTE: expected values for the running example above are
|
||||
# row_indices = [[0], [0], [0], [1], [2], [2]]
|
||||
|
||||
# Now, sum-reduce the `IndexedSlices` that is the result of a
|
||||
# `tape.gradient()` call. The sum is reduced by the repeated embedding
|
||||
# indices and batch index. It is adapted from the logic in:
|
||||
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
|
||||
if not isinstance(base_vars_grads, tf.IndexedSlices):
|
||||
raise NotImplementedError(
|
||||
|
@ -105,20 +142,39 @@ def embedding_layer_computation(
|
|||
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
|
||||
x=paired_indices, axis=[0]
|
||||
)
|
||||
# NOTE: expected values for the running example above are
|
||||
# slice_indices = [[1], [1], [2], [0], [2], [0]]
|
||||
# paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
|
||||
# unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
|
||||
# new_index_positions = [0, 0, 1, 2, 3, 4]
|
||||
|
||||
# Next, sum according to the new positions and compute the squared
|
||||
# gradient norms. Oddly enough, not sorting
|
||||
# these indices will break tensor shape inference logic on TPUs.
|
||||
num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
|
||||
unique_batch_ids = unique_paired_indices[:, 0]
|
||||
summed_gradients = tf.math.unsorted_segment_sum(
|
||||
base_vars_grads.values,
|
||||
new_index_positions,
|
||||
tf.shape(unique_paired_indices)[0],
|
||||
num_unique_paired_indices,
|
||||
)
|
||||
# Compute the squared gradient norms at the per-example level.
|
||||
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
|
||||
summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
|
||||
return tf.sparse.segment_sum(
|
||||
# NOTE: expected values for the running example above are
|
||||
# num_unique_paired_indices = 5
|
||||
# unique_batch_ids = [0, 0, 1, 2, 2]
|
||||
# summed_gradients
|
||||
# = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
|
||||
# = [[0.4], [0.3], [0.1], [0.3], [0.1]]
|
||||
# sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
|
||||
|
||||
# Use a scatter-add strategy to ensure TPU compatibility.
|
||||
result = tf.zeros([nrows])
|
||||
return tf.tensor_scatter_nd_add(
|
||||
result,
|
||||
tf.expand_dims(unique_batch_ids, axis=-1),
|
||||
sqr_gradient_sum,
|
||||
summed_data_range,
|
||||
tf.sort(unique_batch_ids),
|
||||
num_segments=nrows,
|
||||
) # fill in empty inputs
|
||||
)
|
||||
# NOTE: the expected output for the running example is
|
||||
# [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
|
||||
|
||||
return base_vars, outputs, sqr_norm_fn
|
||||
|
|
|
@ -16,6 +16,8 @@ from absl.testing import parameterized
|
|||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -29,41 +31,50 @@ def get_embedding_model_generators():
|
|||
}
|
||||
|
||||
|
||||
def get_embedding_inputs():
|
||||
"""Generates test pairs of the form (is_ragged, input_data)."""
|
||||
return [
|
||||
# 2D inputs.
|
||||
(False, [[0, 1]]),
|
||||
(False, [[0, 1], [1, 1], [0, 0]]),
|
||||
(True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]]),
|
||||
(True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]]),
|
||||
# 3D inputs.
|
||||
(False, [[[0, 1]]]),
|
||||
(False, [[[0, 1]], [[1, 1]], [[0, 0]]]),
|
||||
(True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]),
|
||||
(True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]]),
|
||||
]
|
||||
|
||||
|
||||
def get_embedding_layer_registries():
|
||||
dbl_registry = layer_registry.LayerRegistry()
|
||||
dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
|
||||
dbl_registry.insert(
|
||||
tf.keras.layers.Embedding, embedding.embedding_layer_computation
|
||||
)
|
||||
return {
|
||||
'embed_and_dense': dbl_registry,
|
||||
'default': layer_registry.make_default_layer_registry(),
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main tests.
|
||||
# ==============================================================================
|
||||
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.strategy = tf.distribute.get_strategy()
|
||||
|
||||
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
|
||||
# supports them for embeddings.
|
||||
@parameterized.product(
|
||||
x_batch=[
|
||||
# 2D inputs.
|
||||
tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
|
||||
tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
|
||||
tf.ragged.constant(
|
||||
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
|
||||
),
|
||||
tf.ragged.constant(
|
||||
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
|
||||
dtype=tf.int32,
|
||||
),
|
||||
# 3D inputs.
|
||||
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
|
||||
tf.convert_to_tensor(
|
||||
[[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
|
||||
),
|
||||
tf.ragged.constant(
|
||||
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
|
||||
dtype=tf.int32,
|
||||
),
|
||||
tf.ragged.constant(
|
||||
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
|
||||
dtype=tf.int32,
|
||||
),
|
||||
],
|
||||
x_inputs=get_embedding_inputs(),
|
||||
model_name=list(get_embedding_model_generators().keys()),
|
||||
output_dim=[2],
|
||||
layer_registry_name=list(get_embedding_layer_registries().keys()),
|
||||
per_example_loss_fn=[None, common_test_utils.test_loss_fn],
|
||||
num_microbatches=[None, 2],
|
||||
is_eager=[True, False],
|
||||
|
@ -71,39 +82,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
def test_gradient_norms_on_various_models(
|
||||
self,
|
||||
x_batch,
|
||||
x_inputs,
|
||||
model_name,
|
||||
output_dim,
|
||||
layer_registry_name,
|
||||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
partial,
|
||||
):
|
||||
batch_size = x_batch.shape[0]
|
||||
# The following are invalid test combinations, and are skipped.
|
||||
# Parse inputs to generate test data.
|
||||
is_ragged, input_data = x_inputs
|
||||
embed_indices = (
|
||||
tf.ragged.constant(input_data, dtype=tf.int32)
|
||||
if is_ragged
|
||||
else tf.convert_to_tensor(input_data, dtype_hint=tf.int32)
|
||||
)
|
||||
|
||||
# The following are invalid test combinations and, hence, are skipped.
|
||||
batch_size = embed_indices.shape[0]
|
||||
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
|
||||
if (
|
||||
num_microbatches is not None and batch_size % num_microbatches != 0
|
||||
) or (
|
||||
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
|
||||
(num_microbatches is not None and batch_size % num_microbatches != 0)
|
||||
or (model_name == 'weighted_bow1' and is_ragged)
|
||||
or (
|
||||
# Current clipping ops do not have corresponding TPU kernels.
|
||||
using_tpu
|
||||
and is_ragged
|
||||
)
|
||||
):
|
||||
return
|
||||
default_registry = layer_registry.make_default_layer_registry()
|
||||
model_generator = get_embedding_model_generators()[model_name]
|
||||
computed_norms, true_norms = (
|
||||
common_test_utils.get_computed_and_true_norms(
|
||||
model_generator=model_generator,
|
||||
layer_generator=None,
|
||||
input_dims=x_batch.shape[1:],
|
||||
output_dim=output_dim,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
is_eager=is_eager,
|
||||
x_batch=x_batch,
|
||||
registry=default_registry,
|
||||
partial=partial,
|
||||
)
|
||||
|
||||
# Load shared assets to all devices.
|
||||
with self.strategy.scope():
|
||||
model = common_test_utils.get_model_from_generator(
|
||||
model_generator=get_embedding_model_generators()[model_name],
|
||||
layer_generator=None,
|
||||
input_dims=embed_indices.shape[1:],
|
||||
output_dim=output_dim,
|
||||
is_eager=is_eager,
|
||||
)
|
||||
|
||||
# Define the main testing ops. These may be later compiled to a Graph op.
|
||||
def test_op(x_batch):
|
||||
return common_test_utils.get_computed_and_true_norms_from_model(
|
||||
model=model,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
x_batch=x_batch,
|
||||
registry=get_embedding_layer_registries()[layer_registry_name],
|
||||
partial=partial,
|
||||
)
|
||||
|
||||
# TPUs can only run `tf.function`-decorated functions.
|
||||
if using_tpu:
|
||||
test_op = tf.function(test_op, autograph=False)
|
||||
|
||||
# Set up the device ops and run the test.
|
||||
computed_norms, true_norms = self.strategy.run(
|
||||
test_op, args=(embed_indices,)
|
||||
)
|
||||
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
|
||||
# TPUs return replica contexts, which must be unwrapped.
|
||||
if using_tpu:
|
||||
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
||||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||
computed_norms = computed_norms.values[0]
|
||||
true_norms = true_norms.values[0]
|
||||
expected_size = num_microbatches or batch_size
|
||||
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2023, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding_test
|
||||
|
||||
|
||||
class GradNormTpuTest(embedding_test.GradNormTest):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.strategy = common_test_utils.create_tpu_strategy()
|
||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue