Modify fast clipping logic to support computation on TPUs.

PiperOrigin-RevId: 550673798
A. Unique TensorFlower 2023-07-24 14:28:11 -07:00
parent cb6659d11b
commit c1c97f1c1c
7 changed files with 417 additions and 124 deletions


@@ -15,7 +15,9 @@
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
import tensorflow.compat.v2 as tf_compat
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -24,6 +26,23 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
# ==============================================================================
# Helper functions
# ==============================================================================
def create_tpu_strategy():
"""Initializes a TPU environment."""
# Done to avoid transferring data between CPUs and TPUs.
tf_compat.config.set_soft_device_placement(False)
resolver = tf_compat.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf_compat.config.experimental_connect_to_cluster(resolver)
tf_compat.tpu.experimental.initialize_tpu_system(resolver)
return tf_compat.distribute.TPUStrategy(resolver)
def assert_replica_values_are_close(test_case_obj, replica_context):
"""Checks if all replica context tensors are near each other."""
base_tensor = replica_context.values[0]
for t in replica_context.values[1:]:
test_case_obj.assertAllClose(base_tensor, t)
def get_nd_test_batches(n: int):
"""Returns a list of input batches of dimension n."""
# The first two batches have a single element; the last batch has two elements.
@@ -79,13 +98,76 @@ def compute_true_gradient_norms(
trainable_vars = trainable_vars or input_model.trainable_variables
for var in trainable_vars:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, len(jacobian.shape))
reduction_axes = tf.range(1, tf.rank(jacobian))
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
sqr_norms.append(sqr_norm)
sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
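
As context for the `tf.rank` change above, a minimal standalone sketch of the per-example norm computation this helper performs (assuming, purely for illustration, a single bias-free dense layer and a per-example squared-error loss):

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # batch of two examples
y = tf.constant([[1.0], [0.0]])
layer = tf.keras.layers.Dense(1, use_bias=False)
layer.build(x.shape)
with tf.GradientTape(persistent=True) as tape:
  per_example_loss = tf.reduce_sum(tf.square(layer(x) - y), axis=1)
sqr_norms = []
for var in layer.trainable_variables:
  jac = tape.jacobian(per_example_loss, var, experimental_use_pfor=False)
  # Reduce over every axis except the leading batch axis; `tf.rank` also
  # works when the static shape is only partially known (e.g., on TPUs).
  reduction_axes = tf.range(1, tf.rank(jac))
  sqr_norms.append(tf.reduce_sum(tf.square(jac), axis=reduction_axes))
true_norms = tf.sqrt(tf.add_n(sqr_norms))  # shape [2]: one norm per example
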
def get_model_from_generator(
model_generator: type_aliases.ModelGenerator,
layer_generator: type_aliases.LayerGenerator,
input_dims: Union[int, List[int]],
output_dim: int,
is_eager: bool,
) -> tf.keras.Model:
"""Creates a simple model from input specifications."""
model = model_generator(layer_generator, input_dims, output_dim)
model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
loss=tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
),
run_eagerly=is_eager,
)
return model
def get_computed_and_true_norms_from_model(
model: tf.keras.Model,
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
x_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
partial: bool = False,
):
"""Generates relevant norms from an input model and other specs."""
trainable_vars = None
if partial:
# Gets the first layer with variables.
for l in model.layers:
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
)
return computed_norms, true_norms
def get_computed_and_true_norms(
model_generator: type_aliases.ModelGenerator,
layer_generator: type_aliases.LayerGenerator,
@@ -133,45 +215,23 @@ def get_computed_and_true_norms(
model and layer generators. The second element contains the true clipped
gradient norms under the aforementioned setting.
"""
model = model_generator(layer_generator, input_dims, output_dim)
model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
loss=tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
),
run_eagerly=is_eager,
model = get_model_from_generator(
model_generator=model_generator,
layer_generator=layer_generator,
input_dims=input_dims,
output_dim=output_dim,
is_eager=is_eager,
)
trainable_vars = None
if partial:
# Gets the first layer with variables.
for l in model.layers:
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry,
return get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
x_batch=x_batch,
weight_batch=weight_batch,
rng_seed=rng_seed,
registry=registry,
partial=partial,
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
model,
x_batch,
y_batch,
weight_batch,
per_example_loss_fn,
num_microbatches,
trainable_vars=trainable_vars,
)
return (computed_norms, true_norms)
# ==============================================================================
@@ -234,7 +294,11 @@ def make_bow_model(layer_generator, input_dims, output_dim):
# in each example.
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape))
# The embedding layer adds one extra dimension to its inputs, which, combined
# with the batch dimension at dimension 0, gives two additional dimensions
# relative to the number of input dimensions. Here, we want to reduce over the
# output space but exclude the batch dimension.
reduction_axes = range(1, len(input_dims) + 2)
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
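
As a quick shape check for the comment above, a small sketch with hypothetical values input_dims = [3] and output_dim = 2:

import tensorflow as tf

ids = tf.zeros([5, 3], dtype=tf.int32)                       # [batch, 3]
feature_embs = tf.keras.layers.Embedding(input_dim=10, output_dim=2)(ids)
print(feature_embs.shape)                                    # (5, 3, 2)
# reduction_axes = range(1, len(input_dims) + 2) == [1, 2]: reduce over the
# input positions and the embedding (output) dimension, keep the batch axis.
example_embs = tf.expand_dims(
    tf.reduce_sum(feature_embs, axis=list(range(1, 3))), axis=-1
)
print(example_embs.shape)                                     # (5, 1)
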
@@ -253,7 +317,11 @@ def make_dense_bow_model(layer_generator, input_dims, output_dim):
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape))
# The embedding layer adds one extra dimension to its inputs, which, combined
# with the batch dimension at dimension 0, gives two additional dimensions
# relative to the number of input dimensions. Here, we want to reduce over the
# output space but exclude the batch dimension.
reduction_axes = range(1, len(input_dims) + 2)
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
@@ -274,9 +342,21 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
feature_weights = tf.random.uniform(tf.shape(feature_embs))
# Use deterministic weights to avoid seeding issues on TPUs.
feature_shape = input_dims + [output_dim]
feature_weights = tf.expand_dims(
tf.reshape(
tf.range(np.product(feature_shape), dtype=tf.float32),
feature_shape,
),
axis=0,
)
weighted_embs = feature_embs * feature_weights
reduction_axes = tf.range(1, len(weighted_embs.shape))
# The embedding layer adds one extra dimension to its inputs, which, combined
# with the batch dimension at dimension 0, gives two additional dimensions
# relative to the number of input dimensions. Here, we want to reduce over the
# output space but exclude the batch dimension.
reduction_axes = range(1, len(input_dims) + 2)
example_embs = tf.expand_dims(
tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
)
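
Taken together, a minimal sketch of how these new helpers compose in the TPU test variants added below (assuming a reachable TPU runtime; `tiny_model_generator` is a hypothetical one-layer generator used only for illustration):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

strategy = ctu.create_tpu_strategy()

def tiny_model_generator(layer_generator, input_dims, output_dim):
  # Ignores `layer_generator`; builds a single dense layer.
  return tf.keras.Sequential(
      [tf.keras.Input(shape=(input_dims,)), tf.keras.layers.Dense(output_dim)]
  )

with strategy.scope():
  model = ctu.get_model_from_generator(
      model_generator=tiny_model_generator,
      layer_generator=None,
      input_dims=4,
      output_dim=2,
      is_eager=False,
  )

# The dense tests additionally pass jit_compile=True and autograph=False.
@tf.function
def test_op(x_batch):
  return ctu.get_computed_and_true_norms_from_model(
      model=model,
      per_example_loss_fn=None,
      num_microbatches=None,
      x_batch=x_batch,
      registry=layer_registry.make_default_layer_registry(),
  )

computed, true = strategy.run(test_op, args=(tf.random.uniform([2, 4]),))
# `strategy.run` returns per-replica values on TPUs; unwrap one replica.
computed, true = computed.values[0], true.values[0]
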


@@ -19,9 +19,10 @@ py_test(
size = "large",
srcs = ["dense_test.py"],
python_version = "PY3",
shard_count = 8,
shard_count = 12,
srcs_version = "PY3",
deps = [
":dense",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
@@ -37,12 +38,13 @@ py_library(
py_test(
name = "embedding_test",
size = "large",
srcs = ["embedding_test.py"],
python_version = "PY3",
shard_count = 8,
shard_count = 12,
srcs_version = "PY3",
deps = [
":dense",
":embedding",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",


@@ -16,6 +16,7 @@ from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
# ==============================================================================
@@ -40,16 +41,30 @@ def get_dense_model_generators():
}
def get_dense_layer_registries():
dense_registry = layer_registry.LayerRegistry()
dense_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
return {
'dense_only': dense_registry,
'default': layer_registry.make_default_layer_registry(),
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
@parameterized.product(
model_name=list(get_dense_model_generators().keys()),
layer_name=list(get_dense_layer_generators().keys()),
input_dim=[4],
output_dim=[2],
layer_registry_name=list(get_dense_layer_registries().keys()),
per_example_loss_fn=[None, common_test_utils.test_loss_fn],
num_microbatches=[None, 1, 2],
is_eager=[True, False],
@@ -62,38 +77,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
layer_name,
input_dim,
output_dim,
layer_registry_name,
per_example_loss_fn,
num_microbatches,
is_eager,
partial,
weighted,
):
model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name]
# Parse inputs to generate test data.
x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry()
# Load shared assets to all devices.
with self.strategy.scope():
model = common_test_utils.get_model_from_generator(
model_generator=get_dense_model_generators()[model_name],
layer_generator=get_dense_layer_generators()[layer_name],
input_dims=input_dim,
output_dim=output_dim,
is_eager=is_eager,
)
# Define the main testing ops. These may later be compiled to a Graph op.
def test_op(x_batch, weight_batch):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=get_dense_layer_registries()[layer_registry_name],
partial=partial,
)
# TPUs can only run `tf.function`-decorated functions.
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if using_tpu:
test_op = tf.function(test_op, jit_compile=True, autograph=False)
# TPUs use lower precision than CPUs, so we relax our criterion.
# E.g., one of the TPU runs generated the following results:
#
# computed_norm = 22.530651
# true_norm = 22.570976
# abs_diff = 0.04032516
# rel_diff = 0.00178659
#
# which is a reasonable level of error for computing gradient norms.
# Other trials also give an absolute (resp. relative) error of around
# 0.05 (resp. 0.0015).
rtol = 1e-2 if using_tpu else 1e-3
atol = 1e-1 if using_tpu else 1e-2
for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0]
if num_microbatches is not None and batch_size % num_microbatches != 0:
continue
computed_norms, true_norms = (
common_test_utils.get_computed_and_true_norms(
model_generator,
layer_generator,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=default_registry,
partial=partial,
)
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(
test_op, args=(x_batch, weight_batch)
)
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
expected_size = num_microbatches or batch_size
self.assertEqual(computed_norms.shape[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
if __name__ == '__main__':


@@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense_test
class GradNormTpuTest(dense_test.GradNormTest):
def setUp(self):
super().setUp()
self.strategy = ctu.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
if __name__ == '__main__':
tf.test.main()


@@ -62,12 +62,46 @@ def embedding_layer_computation(
"'_use_one_hot_matmul' is not supported."
)
input_ids = tf.cast(*input_args, tf.int32)
if len(layer_instance.trainable_variables) != 1:
raise ValueError(
"Only layer instances with only one set of trainable variables"
"are permitted."
)
base_vars = layer_instance.trainable_variables[0]
tape.watch(base_vars)
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
def sqr_norm_fn(base_vars_grads):
# Get a 1D tensor of the row indices.
"""Fast square norm function for Keras embedding layers.
Args:
base_vars_grads: A list of batched embedding gradients.
Returns:
A 1D `tf.Tensor` of squared gradient norms.
NOTE: to help understand the code, we document in the function body the
expected intermediate variables for the running example below:
input_ids = [[1, 1, 2], [0], [2, 0]]
base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
For ease of reference, we also list these variables below:
row_indices = [[0], [0], [0], [1], [2], [2]]
slice_indices = [[1], [1], [2], [0], [2], [0]]
paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
new_index_positions = [0, 0, 1, 2, 3, 4]
num_unique_paired_indices = 5
unique_batch_ids = [0, 0, 1, 2, 2]
summed_gradients
= [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
= [[0.4], [0.3], [0.1], [0.3], [0.1]]
sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
"""
# We first get a 1D tensor of the row indices.
nrows = tf.shape(input_ids)[0]
if isinstance(input_ids, tf.RaggedTensor):
row_indices = tf.expand_dims(
@@ -81,16 +115,19 @@
raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
)
row_indices = tf.cast(row_indices, tf.int32)
row_indices = tf.cast(row_indices, tf.int64)
if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int32)
microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
nrows = num_microbatches
row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int32
tf.math.floordiv(row_indices, microbatch_size), tf.int64
)
# Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()`
# call. The sum is reduced by the repeated embedding indices and batch
# index. It is adapted from the logic in:
# NOTE: expected values for the running example above are
# row_indices = [[0], [0], [0], [1], [2], [2]]
# Now, sum-reduce the `IndexedSlices` that is the result of a
# `tape.gradient()` call. The sum is reduced by the repeated embedding
# indices and batch index. It is adapted from the logic in:
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
if not isinstance(base_vars_grads, tf.IndexedSlices):
raise NotImplementedError(
@@ -105,20 +142,39 @@
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
x=paired_indices, axis=[0]
)
# NOTE: expected values for the running example above are
# slice_indices = [[1], [1], [2], [0], [2], [0]]
# paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
# new_index_positions = [0, 0, 1, 2, 3, 4]
# Next, sum according to the new positions and compute the squared
# gradient norms. Oddly enough, not sorting
# these indices will break tensor shape inference logic on TPUs.
num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
unique_batch_ids = unique_paired_indices[:, 0]
summed_gradients = tf.math.unsorted_segment_sum(
base_vars_grads.values,
new_index_positions,
tf.shape(unique_paired_indices)[0],
num_unique_paired_indices,
)
# Compute the squared gradient norms at the per-example level.
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
return tf.sparse.segment_sum(
# NOTE: expected values for the running example above are
# num_unique_paired_indices = 5
# unique_batch_ids = [0, 0, 1, 2, 2]
# summed_gradients
# = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
# = [[0.4], [0.3], [0.1], [0.3], [0.1]]
# sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
# Use a scatter-add strategy to ensure TPU compatibility.
result = tf.zeros([nrows])
return tf.tensor_scatter_nd_add(
result,
tf.expand_dims(unique_batch_ids, axis=-1),
sqr_gradient_sum,
summed_data_range,
tf.sort(unique_batch_ids),
num_segments=nrows,
) # fill in empty inputs
)
# NOTE: the expected output for the running example is
# [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
return base_vars, outputs, sqr_norm_fn
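
To make the running example in `sqr_norm_fn` concrete, a standalone sketch of the de-duplication and scatter-add steps, with the example's row indices and gradient values hard-coded:

import tensorflow as tf

# Running example: input_ids = [[1, 1, 2], [0], [2, 0]], i.e. three examples.
row_indices = tf.constant([[0], [0], [0], [1], [2], [2]], dtype=tf.int64)
slice_indices = tf.constant([[1], [1], [2], [0], [2], [0]], dtype=tf.int64)
grad_values = tf.constant([[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]])
nrows = 3

# De-duplicate repeated (batch index, embedding row) pairs.
paired_indices = tf.concat([row_indices, slice_indices], axis=1)
unique_paired_indices, new_index_positions = tf.raw_ops.UniqueV2(
    x=paired_indices, axis=[0]
)
summed_gradients = tf.math.unsorted_segment_sum(
    grad_values, new_index_positions, tf.shape(unique_paired_indices)[0]
)
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)

# Scatter-add each pair's squared norm back to its example (TPU-friendly).
unique_batch_ids = unique_paired_indices[:, 0]
result = tf.tensor_scatter_nd_add(
    tf.zeros([nrows]),
    tf.expand_dims(unique_batch_ids, axis=-1),
    sqr_gradient_sum,
)
print(result)  # [0.25, 0.01, 0.1]
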


@@ -16,6 +16,8 @@ from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding
# ==============================================================================
@@ -29,41 +31,50 @@ def get_embedding_model_generators():
}
def get_embedding_inputs():
"""Generates test pairs of the form (is_ragged, input_data)."""
return [
# 2D inputs.
(False, [[0, 1]]),
(False, [[0, 1], [1, 1], [0, 0]]),
(True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]]),
(True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]]),
# 3D inputs.
(False, [[[0, 1]]]),
(False, [[[0, 1]], [[1, 1]], [[0, 0]]]),
(True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]),
(True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]]),
]
def get_embedding_layer_registries():
dbl_registry = layer_registry.LayerRegistry()
dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
dbl_registry.insert(
tf.keras.layers.Embedding, embedding.embedding_layer_computation
)
return {
'embed_and_dense': dbl_registry,
'default': layer_registry.make_default_layer_registry(),
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
# supports them for embeddings.
@parameterized.product(
x_batch=[
# 2D inputs.
tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
dtype=tf.int32,
),
# 3D inputs.
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
tf.convert_to_tensor(
[[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
dtype=tf.int32,
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
dtype=tf.int32,
),
],
x_inputs=get_embedding_inputs(),
model_name=list(get_embedding_model_generators().keys()),
output_dim=[2],
layer_registry_name=list(get_embedding_layer_registries().keys()),
per_example_loss_fn=[None, common_test_utils.test_loss_fn],
num_microbatches=[None, 2],
is_eager=[True, False],
@@ -71,39 +82,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
)
def test_gradient_norms_on_various_models(
self,
x_batch,
x_inputs,
model_name,
output_dim,
layer_registry_name,
per_example_loss_fn,
num_microbatches,
is_eager,
partial,
):
batch_size = x_batch.shape[0]
# The following are invalid test combinations, and are skipped.
# Parse inputs to generate test data.
is_ragged, input_data = x_inputs
embed_indices = (
tf.ragged.constant(input_data, dtype=tf.int32)
if is_ragged
else tf.convert_to_tensor(input_data, dtype_hint=tf.int32)
)
# The following are invalid test combinations and, hence, are skipped.
batch_size = embed_indices.shape[0]
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if (
num_microbatches is not None and batch_size % num_microbatches != 0
) or (
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
(num_microbatches is not None and batch_size % num_microbatches != 0)
or (model_name == 'weighted_bow1' and is_ragged)
or (
# Current clipping ops do not have corresponding TPU kernels.
using_tpu
and is_ragged
)
):
return
default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name]
computed_norms, true_norms = (
common_test_utils.get_computed_and_true_norms(
model_generator=model_generator,
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_batch=x_batch,
registry=default_registry,
partial=partial,
)
# Load shared assets to all devices.
with self.strategy.scope():
model = common_test_utils.get_model_from_generator(
model_generator=get_embedding_model_generators()[model_name],
layer_generator=None,
input_dims=embed_indices.shape[1:],
output_dim=output_dim,
is_eager=is_eager,
)
# Define the main testing ops. These may later be compiled to a Graph op.
def test_op(x_batch):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
x_batch=x_batch,
registry=get_embedding_layer_registries()[layer_registry_name],
partial=partial,
)
# TPUs can only run `tf.function`-decorated functions.
if using_tpu:
test_op = tf.function(test_op, autograph=False)
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(
test_op, args=(embed_indices,)
)
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
expected_size = num_microbatches or batch_size
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)


@@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding_test
class GradNormTpuTest(embedding_test.GradNormTest):
def setUp(self):
super().setUp()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
if __name__ == '__main__':
tf.test.main()