Modify fast clipping logic to support computation on TPUs.

PiperOrigin-RevId: 550673798
This commit is contained in:
A. Unique TensorFlower 2023-07-24 14:28:11 -07:00
parent cb6659d11b
commit c1c97f1c1c
7 changed files with 417 additions and 124 deletions

View file

@ -15,7 +15,9 @@
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow.compat.v2 as tf_compat
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@ -24,6 +26,23 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
# ============================================================================== # ==============================================================================
# Helper functions # Helper functions
# ============================================================================== # ==============================================================================
def create_tpu_strategy():
"""Initializes a TPU environment."""
# Done to avoid transferring data between CPUs and TPUs.
tf_compat.config.set_soft_device_placement(False)
resolver = tf_compat.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf_compat.config.experimental_connect_to_cluster(resolver)
tf_compat.tpu.experimental.initialize_tpu_system(resolver)
return tf_compat.distribute.TPUStrategy(resolver)
def assert_replica_values_are_close(test_case_obj, replica_context):
"""Checks if all replica context tensors are near each other."""
base_tensor = replica_context.values[0]
for t in replica_context.values[1:]:
test_case_obj.assertAllClose(base_tensor, t)
def get_nd_test_batches(n: int): def get_nd_test_batches(n: int):
"""Returns a list of input batches of dimension n.""" """Returns a list of input batches of dimension n."""
# The first two batches have a single element, the last batch has 2 elements. # The first two batches have a single element, the last batch has 2 elements.
@ -79,13 +98,76 @@ def compute_true_gradient_norms(
trainable_vars = trainable_vars or input_model.trainable_variables trainable_vars = trainable_vars or input_model.trainable_variables
for var in trainable_vars: for var in trainable_vars:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False) jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, len(jacobian.shape)) reduction_axes = tf.range(1, tf.rank(jacobian))
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes) sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
sqr_norms.append(sqr_norm) sqr_norms.append(sqr_norm)
sqr_norm_tsr = tf.stack(sqr_norms, axis=1) sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
def get_model_from_generator(
model_generator: type_aliases.ModelGenerator,
layer_generator: type_aliases.LayerGenerator,
input_dims: Union[int, List[int]],
output_dim: int,
is_eager: bool,
) -> tf.keras.Model:
"""Creates a simple model from input specifications."""
model = model_generator(layer_generator, input_dims, output_dim)
model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
loss=tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
),
run_eagerly=is_eager,
)
return model
def get_computed_and_true_norms_from_model(
model: tf.keras.Model,
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
x_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
partial: bool = False,
):
"""Generates relevant norms from an input model and other specs."""
trainable_vars = None
if partial:
# Gets the first layer with variables.
for l in model.layers:
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
)
return computed_norms, true_norms
def get_computed_and_true_norms( def get_computed_and_true_norms(
model_generator: type_aliases.ModelGenerator, model_generator: type_aliases.ModelGenerator,
layer_generator: type_aliases.LayerGenerator, layer_generator: type_aliases.LayerGenerator,
@ -133,45 +215,23 @@ def get_computed_and_true_norms(
model and layer generators. The second element contains the true clipped model and layer generators. The second element contains the true clipped
gradient norms under the aforementioned setting. gradient norms under the aforementioned setting.
""" """
model = model_generator(layer_generator, input_dims, output_dim) model = get_model_from_generator(
model.compile( model_generator=model_generator,
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0), layer_generator=layer_generator,
loss=tf.keras.losses.MeanSquaredError( input_dims=input_dims,
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE output_dim=output_dim,
), is_eager=is_eager,
run_eagerly=is_eager,
) )
trainable_vars = None return get_computed_and_true_norms_from_model(
if partial: model=model,
# Gets the first layer with variables.
for l in model.layers:
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry,
per_example_loss_fn=per_example_loss_fn, per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
trainable_vars=trainable_vars, x_batch=x_batch,
weight_batch=weight_batch,
rng_seed=rng_seed,
registry=registry,
partial=partial,
) )
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
model,
x_batch,
y_batch,
weight_batch,
per_example_loss_fn,
num_microbatches,
trainable_vars=trainable_vars,
)
return (computed_norms, true_norms)
# ============================================================================== # ==============================================================================
@ -234,7 +294,11 @@ def make_bow_model(layer_generator, input_dims, output_dim):
# in eache example. # in eache example.
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim) emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
feature_embs = emb_layer(inputs) feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape)) # Embeddings add one extra dimension to its inputs, which combined with the
# batch dimension at dimension 0, equals two additional dimensions compared
# to the number of input dimensions. Here, we want to reduce over the output
# space, but exclude the batch dimension.
reduction_axes = range(1, len(input_dims) + 2)
example_embs = tf.expand_dims( example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
) )
@ -253,7 +317,11 @@ def make_dense_bow_model(layer_generator, input_dims, output_dim):
input_dim=cardinality, output_dim=output_dim input_dim=cardinality, output_dim=output_dim
) )
feature_embs = emb_layer(inputs) feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape)) # Embeddings add one extra dimension to its inputs, which combined with the
# batch dimension at dimension 0, equals two additional dimensions compared
# to the number of input dimensions. Here, we want to reduce over the output
# space, but exclude the batch dimension.
reduction_axes = range(1, len(input_dims) + 2)
example_embs = tf.expand_dims( example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
) )
@ -274,9 +342,21 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
input_dim=cardinality, output_dim=output_dim input_dim=cardinality, output_dim=output_dim
) )
feature_embs = emb_layer(inputs) feature_embs = emb_layer(inputs)
feature_weights = tf.random.uniform(tf.shape(feature_embs)) # Use deterministic weights to avoid seeding issues on TPUs.
feature_shape = input_dims + [output_dim]
feature_weights = tf.expand_dims(
tf.reshape(
tf.range(np.product(feature_shape), dtype=tf.float32),
feature_shape,
),
axis=0,
)
weighted_embs = feature_embs * feature_weights weighted_embs = feature_embs * feature_weights
reduction_axes = tf.range(1, len(weighted_embs.shape)) # Embeddings add one extra dimension to its inputs, which combined with the
# batch dimension at dimension 0, equals two additional dimensions compared
# to the number of input dimensions. Here, we want to reduce over the output
# space, but exclude the batch dimension.
reduction_axes = range(1, len(input_dims) + 2)
example_embs = tf.expand_dims( example_embs = tf.expand_dims(
tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1 tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
) )

View file

@ -19,9 +19,10 @@ py_test(
size = "large", size = "large",
srcs = ["dense_test.py"], srcs = ["dense_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 8, shard_count = 12,
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":dense",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
@ -37,12 +38,13 @@ py_library(
py_test( py_test(
name = "embedding_test", name = "embedding_test",
size = "large",
srcs = ["embedding_test.py"], srcs = ["embedding_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 8, shard_count = 12,
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":dense",
":embedding",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",

View file

@ -16,6 +16,7 @@ from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
# ============================================================================== # ==============================================================================
@ -40,16 +41,30 @@ def get_dense_model_generators():
} }
def get_dense_layer_registries():
dense_registry = layer_registry.LayerRegistry()
dense_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
return {
'dense_only': dense_registry,
'default': layer_registry.make_default_layer_registry(),
}
# ============================================================================== # ==============================================================================
# Main tests. # Main tests.
# ============================================================================== # ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase): class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
@parameterized.product( @parameterized.product(
model_name=list(get_dense_model_generators().keys()), model_name=list(get_dense_model_generators().keys()),
layer_name=list(get_dense_layer_generators().keys()), layer_name=list(get_dense_layer_generators().keys()),
input_dim=[4], input_dim=[4],
output_dim=[2], output_dim=[2],
layer_registry_name=list(get_dense_layer_registries().keys()),
per_example_loss_fn=[None, common_test_utils.test_loss_fn], per_example_loss_fn=[None, common_test_utils.test_loss_fn],
num_microbatches=[None, 1, 2], num_microbatches=[None, 1, 2],
is_eager=[True, False], is_eager=[True, False],
@ -62,38 +77,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
layer_name, layer_name,
input_dim, input_dim,
output_dim, output_dim,
layer_registry_name,
per_example_loss_fn, per_example_loss_fn,
num_microbatches, num_microbatches,
is_eager, is_eager,
partial, partial,
weighted, weighted,
): ):
model_generator = get_dense_model_generators()[model_name] # Parse inputs to generate test data.
layer_generator = get_dense_layer_generators()[layer_name]
x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim) x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry()
# Load shared assets to all devices.
with self.strategy.scope():
model = common_test_utils.get_model_from_generator(
model_generator=get_dense_model_generators()[model_name],
layer_generator=get_dense_layer_generators()[layer_name],
input_dims=input_dim,
output_dim=output_dim,
is_eager=is_eager,
)
# Define the main testing ops. These may be later compiled to a Graph op.
def test_op(x_batch, weight_batch):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=get_dense_layer_registries()[layer_registry_name],
partial=partial,
)
# TPUs can only run `tf.function`-decorated functions.
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if using_tpu:
test_op = tf.function(test_op, jit_compile=True, autograph=False)
# TPUs use lower precision than CPUs, so we relax our criterion.
# E.g., one of the TPU runs generated the following results:
#
# computed_norm = 22.530651
# true_norm = 22.570976
# abs_diff = 0.04032516
# rel_diff = 0.00178659
#
# which is a reasonable level of error for computing gradient norms.
# Other trials also give an absolute (resp. relative) error of around
# 0.05 (resp. 0.0015).
rtol = 1e-2 if using_tpu else 1e-3
atol = 1e-1 if using_tpu else 1e-2
for x_batch, weight_batch in zip(x_batches, weight_batches): for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0] batch_size = x_batch.shape[0]
if num_microbatches is not None and batch_size % num_microbatches != 0: if num_microbatches is not None and batch_size % num_microbatches != 0:
continue continue
computed_norms, true_norms = ( # Set up the device ops and run the test.
common_test_utils.get_computed_and_true_norms( computed_norms, true_norms = self.strategy.run(
model_generator, test_op, args=(x_batch, weight_batch)
layer_generator,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=default_registry,
partial=partial,
)
) )
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
expected_size = num_microbatches or batch_size expected_size = num_microbatches or batch_size
self.assertEqual(computed_norms.shape[0], expected_size) self.assertEqual(tf.shape(computed_norms)[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense_test
class GradNormTpuTest(dense_test.GradNormTest):
def setUp(self):
super().setUp()
self.strategy = ctu.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
if __name__ == '__main__':
tf.test.main()

View file

@ -62,12 +62,46 @@ def embedding_layer_computation(
"'_use_one_hot_matmul' is not supported." "'_use_one_hot_matmul' is not supported."
) )
input_ids = tf.cast(*input_args, tf.int32) input_ids = tf.cast(*input_args, tf.int32)
if len(layer_instance.trainable_variables) != 1:
raise ValueError(
"Only layer instances with only one set of trainable variables"
"are permitted."
)
base_vars = layer_instance.trainable_variables[0] base_vars = layer_instance.trainable_variables[0]
tape.watch(base_vars) tape.watch(base_vars)
outputs = tf.nn.embedding_lookup(base_vars, input_ids) outputs = tf.nn.embedding_lookup(base_vars, input_ids)
def sqr_norm_fn(base_vars_grads): def sqr_norm_fn(base_vars_grads):
# Get a 1D tensor of the row indices. """Fast square norm function for Keras embedding layers.
Args:
base_vars_grads: A list of batched embedding gradients.
Returns:
A 1D `tf.Tensor` of squared gradient norms.
NOTE: to help understand the code, we document in the function body what
the expected intermediate variables are for the below running example:
input_ids = [[1, 1, 2], [0], [2, 0]]
base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
For ease of reference, we also list these variables below:
row_indices = [[0], [0], [0], [1], [2], [2]]
slice_indices = [[1], [1], [2], [0], [2], [0]]
paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
new_index_positions = [0, 0, 1, 2, 3, 4]
num_unique_paired_indices = 5
unique_batch_ids = [0, 0, 1, 2, 2]
summed_gradients
= [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
= [[0.4], [0.3], [0.1], [0.3], [0.1]]
sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
"""
# We first get a 1D tensor of the row indices.
nrows = tf.shape(input_ids)[0] nrows = tf.shape(input_ids)[0]
if isinstance(input_ids, tf.RaggedTensor): if isinstance(input_ids, tf.RaggedTensor):
row_indices = tf.expand_dims( row_indices = tf.expand_dims(
@ -81,16 +115,19 @@ def embedding_layer_computation(
raise NotImplementedError( raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__ "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
) )
row_indices = tf.cast(row_indices, tf.int32) row_indices = tf.cast(row_indices, tf.int64)
if num_microbatches is not None: if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int32) microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
nrows = num_microbatches nrows = num_microbatches
row_indices = tf.cast( row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int32 tf.math.floordiv(row_indices, microbatch_size), tf.int64
) )
# Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()` # NOTE: expected values for the running example above are
# call. The sum is reduced by the repeated embedding indices and batch # row_indices = [[0], [0], [0], [1], [2], [2]]
# index. It is adapted from the logic in:
# Now, sum-reduce the `IndexedSlices` that is the result of a
# `tape.gradient()` call. The sum is reduced by the repeated embedding
# indices and batch index. It is adapted from the logic in:
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices # tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
if not isinstance(base_vars_grads, tf.IndexedSlices): if not isinstance(base_vars_grads, tf.IndexedSlices):
raise NotImplementedError( raise NotImplementedError(
@ -105,20 +142,39 @@ def embedding_layer_computation(
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2( (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
x=paired_indices, axis=[0] x=paired_indices, axis=[0]
) )
# NOTE: expected values for the running example above are
# slice_indices = [[1], [1], [2], [0], [2], [0]]
# paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# new_index_positions = [0, 0, 1, 2, 3, 4]
# Next, sum according to the new positions and compute the squared
# gradient norms. Oddly enough, not sorting
# these indices will break tensor shape inference logic on TPUs.
num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
unique_batch_ids = unique_paired_indices[:, 0] unique_batch_ids = unique_paired_indices[:, 0]
summed_gradients = tf.math.unsorted_segment_sum( summed_gradients = tf.math.unsorted_segment_sum(
base_vars_grads.values, base_vars_grads.values,
new_index_positions, new_index_positions,
tf.shape(unique_paired_indices)[0], num_unique_paired_indices,
) )
# Compute the squared gradient norms at the per-example level.
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1) sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0]) # NOTE: expected values for the running example above are
return tf.sparse.segment_sum( # num_unique_paired_indices = 5
# unique_batch_ids = [0, 0, 1, 2, 2]
# summed_gradients
# = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
# = [[0.4], [0.3], [0.1], [0.3], [0.1]]
# sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
# Use a scatter-add strategy to ensure TPU compatibility.
result = tf.zeros([nrows])
return tf.tensor_scatter_nd_add(
result,
tf.expand_dims(unique_batch_ids, axis=-1),
sqr_gradient_sum, sqr_gradient_sum,
summed_data_range, )
tf.sort(unique_batch_ids), # NOTE: the expected output for the running example is
num_segments=nrows, # [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
) # fill in empty inputs
return base_vars, outputs, sqr_norm_fn return base_vars, outputs, sqr_norm_fn

View file

@ -16,6 +16,8 @@ from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding
# ============================================================================== # ==============================================================================
@ -29,41 +31,50 @@ def get_embedding_model_generators():
} }
def get_embedding_inputs():
"""Generates test pairs of the form (is_ragged, input_data)."""
return [
# 2D inputs.
(False, [[0, 1]]),
(False, [[0, 1], [1, 1], [0, 0]]),
(True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]]),
(True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]]),
# 3D inputs.
(False, [[[0, 1]]]),
(False, [[[0, 1]], [[1, 1]], [[0, 0]]]),
(True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]),
(True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]]),
]
def get_embedding_layer_registries():
dbl_registry = layer_registry.LayerRegistry()
dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
dbl_registry.insert(
tf.keras.layers.Embedding, embedding.embedding_layer_computation
)
return {
'embed_and_dense': dbl_registry,
'default': layer_registry.make_default_layer_registry(),
}
# ============================================================================== # ==============================================================================
# Main tests. # Main tests.
# ============================================================================== # ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase): class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
# supports them for embeddings. # supports them for embeddings.
@parameterized.product( @parameterized.product(
x_batch=[ x_inputs=get_embedding_inputs(),
# 2D inputs.
tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
dtype=tf.int32,
),
# 3D inputs.
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
tf.convert_to_tensor(
[[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
dtype=tf.int32,
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
dtype=tf.int32,
),
],
model_name=list(get_embedding_model_generators().keys()), model_name=list(get_embedding_model_generators().keys()),
output_dim=[2], output_dim=[2],
layer_registry_name=list(get_embedding_layer_registries().keys()),
per_example_loss_fn=[None, common_test_utils.test_loss_fn], per_example_loss_fn=[None, common_test_utils.test_loss_fn],
num_microbatches=[None, 2], num_microbatches=[None, 2],
is_eager=[True, False], is_eager=[True, False],
@ -71,39 +82,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, self,
x_batch, x_inputs,
model_name, model_name,
output_dim, output_dim,
layer_registry_name,
per_example_loss_fn, per_example_loss_fn,
num_microbatches, num_microbatches,
is_eager, is_eager,
partial, partial,
): ):
batch_size = x_batch.shape[0] # Parse inputs to generate test data.
# The following are invalid test combinations, and are skipped. is_ragged, input_data = x_inputs
embed_indices = (
tf.ragged.constant(input_data, dtype=tf.int32)
if is_ragged
else tf.convert_to_tensor(input_data, dtype_hint=tf.int32)
)
# The following are invalid test combinations and, hence, are skipped.
batch_size = embed_indices.shape[0]
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if ( if (
num_microbatches is not None and batch_size % num_microbatches != 0 (num_microbatches is not None and batch_size % num_microbatches != 0)
) or ( or (model_name == 'weighted_bow1' and is_ragged)
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor) or (
# Current clipping ops do not have corresponding TPU kernels.
using_tpu
and is_ragged
)
): ):
return return
default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name] # Load shared assets to all devices.
computed_norms, true_norms = ( with self.strategy.scope():
common_test_utils.get_computed_and_true_norms( model = common_test_utils.get_model_from_generator(
model_generator=model_generator, model_generator=get_embedding_model_generators()[model_name],
layer_generator=None, layer_generator=None,
input_dims=x_batch.shape[1:], input_dims=embed_indices.shape[1:],
output_dim=output_dim, output_dim=output_dim,
is_eager=is_eager,
)
# Define the main testing ops. These may be later compiled to a Graph op.
def test_op(x_batch):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=per_example_loss_fn, per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
is_eager=is_eager,
x_batch=x_batch, x_batch=x_batch,
registry=default_registry, registry=get_embedding_layer_registries()[layer_registry_name],
partial=partial, partial=partial,
) )
# TPUs can only run `tf.function`-decorated functions.
if using_tpu:
test_op = tf.function(test_op, autograph=False)
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(
test_op, args=(embed_indices,)
) )
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size) # TPUs return replica contexts, which must be unwrapped.
if using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
expected_size = num_microbatches or batch_size
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)

View file

@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding_test
class GradNormTpuTest(embedding_test.GradNormTest):
def setUp(self):
super().setUp()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
if __name__ == '__main__':
tf.test.main()