Modify fast clipping logic to support computation on TPUs.
PiperOrigin-RevId: 550673798
commit c1c97f1c1c (parent cb6659d11b)
7 changed files with 417 additions and 124 deletions
tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py
@@ -15,7 +15,9 @@
 from typing import Callable, List, Optional, Tuple, Union
+import numpy as np
 import tensorflow as tf
+import tensorflow.compat.v2 as tf_compat
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -24,6 +26,23 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 # ==============================================================================
 # Helper functions
 # ==============================================================================
+def create_tpu_strategy():
+  """Initializes a TPU environment."""
+  # Done to avoid transferring data between CPUs and TPUs.
+  tf_compat.config.set_soft_device_placement(False)
+  resolver = tf_compat.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+  tf_compat.config.experimental_connect_to_cluster(resolver)
+  tf_compat.tpu.experimental.initialize_tpu_system(resolver)
+  return tf_compat.distribute.TPUStrategy(resolver)
+
+
+def assert_replica_values_are_close(test_case_obj, replica_context):
+  """Checks if all replica context tensors are near each other."""
+  base_tensor = replica_context.values[0]
+  for t in replica_context.values[1:]:
+    test_case_obj.assertAllClose(base_tensor, t)
+
+
 def get_nd_test_batches(n: int):
   """Returns a list of input batches of dimension n."""
   # The first two batches have a single element, the last batch has 2 elements.
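For context, a minimal sketch of how the two new helpers are meant to compose; the `per_replica_op` body and its `tf.ones` input are illustrative placeholders, not part of this commit:

    # Sketch only: assumes a reachable TPU worker.
    strategy = create_tpu_strategy()

    @tf.function  # TPUs can only run `tf.function`-decorated ops.
    def per_replica_op(x):  # hypothetical op standing in for a test body
      return tf.norm(x, axis=1)

    # On a multi-replica strategy, `run` returns a PerReplica value; the
    # tests below pass it to `assert_replica_values_are_close` to check
    # that every replica computed the same result.
    per_replica = strategy.run(per_replica_op, args=(tf.ones([4, 3]),))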
@@ -79,13 +98,76 @@ def compute_true_gradient_norms(
   trainable_vars = trainable_vars or input_model.trainable_variables
   for var in trainable_vars:
     jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
-    reduction_axes = tf.range(1, len(jacobian.shape))
+    reduction_axes = tf.range(1, tf.rank(jacobian))
     sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
     sqr_norms.append(sqr_norm)
   sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+
+
+def get_model_from_generator(
+    model_generator: type_aliases.ModelGenerator,
+    layer_generator: type_aliases.LayerGenerator,
+    input_dims: Union[int, List[int]],
+    output_dim: int,
+    is_eager: bool,
+) -> tf.keras.Model:
+  """Creates a simple model from input specifications."""
+  model = model_generator(layer_generator, input_dims, output_dim)
+  model.compile(
+      optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
+      loss=tf.keras.losses.MeanSquaredError(
+          reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
+      ),
+      run_eagerly=is_eager,
+  )
+  return model
+
+
+def get_computed_and_true_norms_from_model(
+    model: tf.keras.Model,
+    per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
+    num_microbatches: Optional[int],
+    x_batch: tf.Tensor,
+    weight_batch: Optional[tf.Tensor] = None,
+    rng_seed: int = 777,
+    registry: layer_registry.LayerRegistry = None,
+    partial: bool = False,
+):
+  """Generates relevant norms from an input model and other specs."""
+  trainable_vars = None
+  if partial:
+    # Gets the first layer with variables.
+    for l in model.layers:
+      trainable_vars = l.trainable_variables
+      if trainable_vars:
+        break
+  y_pred = model(x_batch)
+  y_batch = tf.ones_like(y_pred)
+  tf.keras.utils.set_random_seed(rng_seed)
+  computed_norms = clip_grads.compute_gradient_norms(
+      input_model=model,
+      x_batch=x_batch,
+      y_batch=y_batch,
+      weight_batch=weight_batch,
+      layer_registry=registry,
+      per_example_loss_fn=per_example_loss_fn,
+      num_microbatches=num_microbatches,
+      trainable_vars=trainable_vars,
+  )
+  tf.keras.utils.set_random_seed(rng_seed)
+  true_norms = compute_true_gradient_norms(
+      input_model=model,
+      x_batch=x_batch,
+      y_batch=y_batch,
+      weight_batch=weight_batch,
+      per_example_loss_fn=per_example_loss_fn,
+      num_microbatches=num_microbatches,
+      trainable_vars=trainable_vars,
+  )
+  return computed_norms, true_norms
+
+
 def get_computed_and_true_norms(
     model_generator: type_aliases.ModelGenerator,
     layer_generator: type_aliases.LayerGenerator,
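The `len(jacobian.shape)` to `tf.rank(jacobian)` switch above matters once this code is traced into a compiled graph: a tensor's static rank can be unknown at trace time, in which case `len(t.shape)` raises, while `tf.rank` is evaluated when the graph runs. A standalone illustration (not from this commit):

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(shape=None)])  # unknown rank
    def sum_non_batch_axes(t):
      # len(t.shape) would raise ValueError here because the static rank is
      # unknown at trace time; tf.rank(t) is a scalar tensor computed at
      # graph execution time.
      return tf.reduce_sum(t, axis=tf.range(1, tf.rank(t)))

    print(sum_non_batch_axes(tf.ones([2, 3, 4])))  # [12. 12.]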
@@ -133,45 +215,23 @@ def get_computed_and_true_norms(
     model and layer generators. The second element contains the true clipped
     gradient norms under the aforementioned setting.
   """
-  model = model_generator(layer_generator, input_dims, output_dim)
-  model.compile(
-      optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
-      loss=tf.keras.losses.MeanSquaredError(
-          reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
-      ),
-      run_eagerly=is_eager,
-  )
-  trainable_vars = None
-  if partial:
-    # Gets the first layer with variables.
-    for l in model.layers:
-      trainable_vars = l.trainable_variables
-      if trainable_vars:
-        break
-  y_pred = model(x_batch)
-  y_batch = tf.ones_like(y_pred)
-  tf.keras.utils.set_random_seed(rng_seed)
-  computed_norms = clip_grads.compute_gradient_norms(
-      input_model=model,
-      x_batch=x_batch,
-      y_batch=y_batch,
-      weight_batch=weight_batch,
-      layer_registry=registry,
-      per_example_loss_fn=per_example_loss_fn,
-      num_microbatches=num_microbatches,
-      trainable_vars=trainable_vars,
-  )
-  tf.keras.utils.set_random_seed(rng_seed)
-  true_norms = compute_true_gradient_norms(
-      model,
-      x_batch,
-      y_batch,
-      weight_batch,
-      per_example_loss_fn,
-      num_microbatches,
-      trainable_vars=trainable_vars,
-  )
-  return (computed_norms, true_norms)
+  model = get_model_from_generator(
+      model_generator=model_generator,
+      layer_generator=layer_generator,
+      input_dims=input_dims,
+      output_dim=output_dim,
+      is_eager=is_eager,
+  )
+  return get_computed_and_true_norms_from_model(
+      model=model,
+      per_example_loss_fn=per_example_loss_fn,
+      num_microbatches=num_microbatches,
+      x_batch=x_batch,
+      weight_batch=weight_batch,
+      rng_seed=rng_seed,
+      registry=registry,
+      partial=partial,
+  )


 # ==============================================================================
@@ -234,7 +294,11 @@ def make_bow_model(layer_generator, input_dims, output_dim):
   # in each example.
   emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
   feature_embs = emb_layer(inputs)
-  reduction_axes = tf.range(1, len(feature_embs.shape))
+  # Embeddings add one extra dimension to their inputs, which, combined with
+  # the batch dimension at dimension 0, gives two additional dimensions
+  # compared to the number of input dimensions. Here, we want to reduce over
+  # the output space, but exclude the batch dimension.
+  reduction_axes = range(1, len(input_dims) + 2)
   example_embs = tf.expand_dims(
       tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
   )
@@ -253,7 +317,11 @@ def make_dense_bow_model(layer_generator, input_dims, output_dim):
       input_dim=cardinality, output_dim=output_dim
   )
   feature_embs = emb_layer(inputs)
-  reduction_axes = tf.range(1, len(feature_embs.shape))
+  # Embeddings add one extra dimension to their inputs, which, combined with
+  # the batch dimension at dimension 0, gives two additional dimensions
+  # compared to the number of input dimensions. Here, we want to reduce over
+  # the output space, but exclude the batch dimension.
+  reduction_axes = range(1, len(input_dims) + 2)
   example_embs = tf.expand_dims(
       tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
   )
@@ -274,9 +342,21 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
       input_dim=cardinality, output_dim=output_dim
   )
   feature_embs = emb_layer(inputs)
-  feature_weights = tf.random.uniform(tf.shape(feature_embs))
+  # Use deterministic weights to avoid seeding issues on TPUs.
+  feature_shape = input_dims + [output_dim]
+  feature_weights = tf.expand_dims(
+      tf.reshape(
+          tf.range(np.product(feature_shape), dtype=tf.float32),
+          feature_shape,
+      ),
+      axis=0,
+  )
   weighted_embs = feature_embs * feature_weights
-  reduction_axes = tf.range(1, len(weighted_embs.shape))
+  # Embeddings add one extra dimension to their inputs, which, combined with
+  # the batch dimension at dimension 0, gives two additional dimensions
+  # compared to the number of input dimensions. Here, we want to reduce over
+  # the output space, but exclude the batch dimension.
+  reduction_axes = range(1, len(input_dims) + 2)
   example_embs = tf.expand_dims(
       tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
   )
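The deterministic weights above are just a reshaped `tf.range`, so they are easy to spot-check. For hypothetical `input_dims = [2]` and `output_dim = 2` (and using `np.prod`, the non-deprecated spelling of `np.product`):

    import numpy as np
    import tensorflow as tf

    input_dims, output_dim = [2], 2  # illustrative values only
    feature_shape = input_dims + [output_dim]
    feature_weights = tf.expand_dims(
        tf.reshape(
            tf.range(np.prod(feature_shape), dtype=tf.float32), feature_shape
        ),
        axis=0,
    )
    # feature_weights has shape (1, 2, 2) with values [[[0., 1.], [2., 3.]]];
    # the leading axis of size 1 broadcasts over the batch dimension.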
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
@@ -19,9 +19,10 @@ py_test(
     size = "large",
     srcs = ["dense_test.py"],
     python_version = "PY3",
-    shard_count = 8,
+    shard_count = 12,
     srcs_version = "PY3",
     deps = [
+        ":dense",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
@@ -37,12 +38,13 @@ py_library(


 py_test(
     name = "embedding_test",
-    size = "large",
     srcs = ["embedding_test.py"],
     python_version = "PY3",
-    shard_count = 8,
+    shard_count = 12,
     srcs_version = "PY3",
     deps = [
+        ":dense",
+        ":embedding",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py
@@ -16,6 +16,7 @@ from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense


 # ==============================================================================
@@ -40,16 +41,30 @@ def get_dense_model_generators():
   }


+def get_dense_layer_registries():
+  dense_registry = layer_registry.LayerRegistry()
+  dense_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
+  return {
+      'dense_only': dense_registry,
+      'default': layer_registry.make_default_layer_registry(),
+  }
+
+
 # ==============================================================================
 # Main tests.
 # ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):

+  def setUp(self):
+    super().setUp()
+    self.strategy = tf.distribute.get_strategy()
+
   @parameterized.product(
       model_name=list(get_dense_model_generators().keys()),
       layer_name=list(get_dense_layer_generators().keys()),
       input_dim=[4],
       output_dim=[2],
+      layer_registry_name=list(get_dense_layer_registries().keys()),
       per_example_loss_fn=[None, common_test_utils.test_loss_fn],
       num_microbatches=[None, 1, 2],
       is_eager=[True, False],
@@ -62,38 +77,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       layer_name,
       input_dim,
       output_dim,
+      layer_registry_name,
       per_example_loss_fn,
       num_microbatches,
       is_eager,
       partial,
       weighted,
   ):
-    model_generator = get_dense_model_generators()[model_name]
-    layer_generator = get_dense_layer_generators()[layer_name]
+    # Parse inputs to generate test data.
     x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim)
-    default_registry = layer_registry.make_default_layer_registry()
+
+    # Load shared assets to all devices.
+    with self.strategy.scope():
+      model = common_test_utils.get_model_from_generator(
+          model_generator=get_dense_model_generators()[model_name],
+          layer_generator=get_dense_layer_generators()[layer_name],
+          input_dims=input_dim,
+          output_dim=output_dim,
+          is_eager=is_eager,
+      )
+
+    # Define the main testing ops. These may be later compiled to a Graph op.
+    def test_op(x_batch, weight_batch):
+      return common_test_utils.get_computed_and_true_norms_from_model(
+          model=model,
+          per_example_loss_fn=per_example_loss_fn,
+          num_microbatches=num_microbatches,
+          x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
+          weight_batch=weight_batch if weighted else None,
+          registry=get_dense_layer_registries()[layer_registry_name],
+          partial=partial,
+      )
+
+    # TPUs can only run `tf.function`-decorated functions.
+    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
+    if using_tpu:
+      test_op = tf.function(test_op, jit_compile=True, autograph=False)
+
+    # TPUs use lower precision than CPUs, so we relax our criterion.
+    # E.g., one of the TPU runs generated the following results:
+    #
+    #   computed_norm = 22.530651
+    #   true_norm = 22.570976
+    #   abs_diff = 0.04032516
+    #   rel_diff = 0.00178659
+    #
+    # which is a reasonable level of error for computing gradient norms.
+    # Other trials also give an absolute (resp. relative) error of around
+    # 0.05 (resp. 0.0015).
+    rtol = 1e-2 if using_tpu else 1e-3
+    atol = 1e-1 if using_tpu else 1e-2
+
     for x_batch, weight_batch in zip(x_batches, weight_batches):
       batch_size = x_batch.shape[0]
       if num_microbatches is not None and batch_size % num_microbatches != 0:
         continue
-      computed_norms, true_norms = (
-          common_test_utils.get_computed_and_true_norms(
-              model_generator,
-              layer_generator,
-              input_dim,
-              output_dim,
-              per_example_loss_fn,
-              num_microbatches,
-              is_eager,
-              x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
-              weight_batch=weight_batch if weighted else None,
-              registry=default_registry,
-              partial=partial,
-          )
-      )
+      # Set up the device ops and run the test.
+      computed_norms, true_norms = self.strategy.run(
+          test_op, args=(x_batch, weight_batch)
+      )
+      # TPUs return replica contexts, which must be unwrapped.
+      if using_tpu:
+        common_test_utils.assert_replica_values_are_close(self, computed_norms)
+        common_test_utils.assert_replica_values_are_close(self, true_norms)
+        computed_norms = computed_norms.values[0]
+        true_norms = true_norms.values[0]
       expected_size = num_microbatches or batch_size
-      self.assertEqual(computed_norms.shape[0], expected_size)
-      self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
+      self.assertEqual(tf.shape(computed_norms)[0], expected_size)
+      self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)


 if __name__ == '__main__':
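The unwrapping logic above (inherited by the TPU test below) relies on `strategy.run` returning a `PerReplica` object whose `.values` tuple holds one tensor per replica. A minimal sketch of the same pattern, with `MirroredStrategy` standing in for a TPU strategy since the unwrapping is identical:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # stand-in for TPUStrategy

    @tf.function
    def norms_op(x):
      return tf.norm(x, axis=1)

    result = strategy.run(norms_op, args=(tf.ones([2, 3]),))
    # With more than one replica, `result` is a PerReplica; after checking
    # that the replicas agree, keep a single copy.
    if hasattr(result, 'values'):
      result = result.values[0]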
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_tpu_test.py (new file)
@@ -0,0 +1,29 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense_test
+
+
+class GradNormTpuTest(dense_test.GradNormTest):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = ctu.create_tpu_strategy()
+    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+
+
+if __name__ == '__main__':
+  tf.test.main()
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py
@@ -62,12 +62,46 @@ def embedding_layer_computation(
       "'_use_one_hot_matmul' is not supported."
   )
   input_ids = tf.cast(*input_args, tf.int32)
+  if len(layer_instance.trainable_variables) != 1:
+    raise ValueError(
+        "Only layer instances with only one set of trainable variables "
+        "are permitted."
+    )
   base_vars = layer_instance.trainable_variables[0]
   tape.watch(base_vars)
   outputs = tf.nn.embedding_lookup(base_vars, input_ids)

   def sqr_norm_fn(base_vars_grads):
-    # Get a 1D tensor of the row indices.
+    """Fast square norm function for Keras embedding layers.
+
+    Args:
+      base_vars_grads: A list of batched embedding gradients.
+
+    Returns:
+      A 1D `tf.Tensor` of squared gradient norms.
+
+    NOTE: to help understand the code, we document in the function body what
+    the expected intermediate variables are for the below running example:
+
+      input_ids = [[1, 1, 2], [0], [2, 0]]
+      base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
+      base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
+
+    For ease of reference, we also list these variables below:
+
+      row_indices = [[0], [0], [0], [1], [2], [2]]
+      slice_indices = [[1], [1], [2], [0], [2], [0]]
+      paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+      unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+      new_index_positions = [0, 0, 1, 2, 3, 4]
+      num_unique_paired_indices = 5
+      unique_batch_ids = [0, 0, 1, 2, 2]
+      summed_gradients
+        = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
+        = [[0.4], [0.3], [0.1], [0.3], [0.1]]
+      sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
+    """
+    # We first get a 1D tensor of the row indices.
     nrows = tf.shape(input_ids)[0]
     if isinstance(input_ids, tf.RaggedTensor):
       row_indices = tf.expand_dims(
@@ -81,16 +115,19 @@ def embedding_layer_computation(
       raise NotImplementedError(
           "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
       )
-    row_indices = tf.cast(row_indices, tf.int32)
+    row_indices = tf.cast(row_indices, tf.int64)
     if num_microbatches is not None:
-      microbatch_size = tf.cast(nrows / num_microbatches, tf.int32)
+      microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
       nrows = num_microbatches
       row_indices = tf.cast(
-          tf.math.floordiv(row_indices, microbatch_size), tf.int32
+          tf.math.floordiv(row_indices, microbatch_size), tf.int64
       )
-    # Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()`
-    # call. The sum is reduced by the repeated embedding indices and batch
-    # index. It is adapted from the logic in:
+    # NOTE: expected values for the running example above are
+    #   row_indices = [[0], [0], [0], [1], [2], [2]]
+
+    # Now, sum-reduce the `IndexedSlices` that is the result of a
+    # `tape.gradient()` call. The sum is reduced by the repeated embedding
+    # indices and batch index. It is adapted from the logic in:
     #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
     if not isinstance(base_vars_grads, tf.IndexedSlices):
       raise NotImplementedError(
@@ -105,20 +142,39 @@ def embedding_layer_computation(
     (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
         x=paired_indices, axis=[0]
     )
+    # NOTE: expected values for the running example above are
+    #   slice_indices = [[1], [1], [2], [0], [2], [0]]
+    #   paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+    #   unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+    #   new_index_positions = [0, 0, 1, 2, 3, 4]
+
+    # Next, sum according to the new positions and compute the squared
+    # gradient norms. Oddly enough, not sorting these indices will break
+    # tensor shape inference logic on TPUs.
+    num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
     unique_batch_ids = unique_paired_indices[:, 0]
     summed_gradients = tf.math.unsorted_segment_sum(
         base_vars_grads.values,
         new_index_positions,
-        tf.shape(unique_paired_indices)[0],
+        num_unique_paired_indices,
     )
-    # Compute the squared gradient norms at the per-example level.
     sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
-    summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
-    return tf.sparse.segment_sum(
-        sqr_gradient_sum,
-        summed_data_range,
-        tf.sort(unique_batch_ids),
-        num_segments=nrows,
-    )  # fill in empty inputs
+    # NOTE: expected values for the running example above are
+    #   num_unique_paired_indices = 5
+    #   unique_batch_ids = [0, 0, 1, 2, 2]
+    #   summed_gradients
+    #     = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
+    #     = [[0.4], [0.3], [0.1], [0.3], [0.1]]
+    #   sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
+
+    # Use a scatter-add strategy to ensure TPU compatibility.
+    #
+    # NOTE: the expected output for the running example is
+    #   [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
+    result = tf.zeros([nrows])
+    return tf.tensor_scatter_nd_add(
+        result,
+        tf.expand_dims(unique_batch_ids, axis=-1),
+        sqr_gradient_sum,
+    )

   return base_vars, outputs, sqr_norm_fn
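The scatter-add that replaces `tf.sparse.segment_sum` can be verified directly against the running example documented in the new docstring (a standalone check, not part of the commit):

    import tensorflow as tf

    nrows = 3  # batch size in the docstring's running example
    unique_batch_ids = tf.constant([0, 0, 1, 2, 2])
    sqr_gradient_sum = tf.constant([0.16, 0.09, 0.01, 0.09, 0.01])

    result = tf.tensor_scatter_nd_add(
        tf.zeros([nrows]),
        tf.expand_dims(unique_batch_ids, axis=-1),
        sqr_gradient_sum,
    )
    print(result.numpy())  # [0.25 0.01 0.1], matching the documented output

Unlike `tf.sparse.segment_sum`, this formulation needs neither sorted segment ids nor a range tensor, which matches the comment above about sorting and shape inference on TPUs.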
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py
@@ -16,6 +16,8 @@ from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding


 # ==============================================================================
@@ -29,41 +31,50 @@ def get_embedding_model_generators():
   }


+def get_embedding_inputs():
+  """Generates test pairs of the form (is_ragged, input_data)."""
+  return [
+      # 2D inputs.
+      (False, [[0, 1]]),
+      (False, [[0, 1], [1, 1], [0, 0]]),
+      (True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]]),
+      (True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]]),
+      # 3D inputs.
+      (False, [[[0, 1]]]),
+      (False, [[[0, 1]], [[1, 1]], [[0, 0]]]),
+      (True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]),
+      (True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]]),
+  ]
+
+
+def get_embedding_layer_registries():
+  dbl_registry = layer_registry.LayerRegistry()
+  dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
+  dbl_registry.insert(
+      tf.keras.layers.Embedding, embedding.embedding_layer_computation
+  )
+  return {
+      'embed_and_dense': dbl_registry,
+      'default': layer_registry.make_default_layer_registry(),
+  }
+
+
 # ==============================================================================
 # Main tests.
 # ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):

+  def setUp(self):
+    super().setUp()
+    self.strategy = tf.distribute.get_strategy()
+
   # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
   @parameterized.product(
-      x_batch=[
-          # 2D inputs.
-          tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
-          tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
-          tf.ragged.constant(
-              [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
-          ),
-          tf.ragged.constant(
-              [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
-              dtype=tf.int32,
-          ),
-          # 3D inputs.
-          tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
-          tf.convert_to_tensor(
-              [[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
-          ),
-          tf.ragged.constant(
-              [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
-              dtype=tf.int32,
-          ),
-          tf.ragged.constant(
-              [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
-              dtype=tf.int32,
-          ),
-      ],
+      x_inputs=get_embedding_inputs(),
       model_name=list(get_embedding_model_generators().keys()),
       output_dim=[2],
+      layer_registry_name=list(get_embedding_layer_registries().keys()),
       per_example_loss_fn=[None, common_test_utils.test_loss_fn],
       num_microbatches=[None, 2],
       is_eager=[True, False],
@@ -71,39 +82,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
   )
   def test_gradient_norms_on_various_models(
       self,
-      x_batch,
+      x_inputs,
       model_name,
       output_dim,
+      layer_registry_name,
       per_example_loss_fn,
       num_microbatches,
       is_eager,
       partial,
   ):
-    batch_size = x_batch.shape[0]
-    # The following are invalid test combinations, and are skipped.
+    # Parse inputs to generate test data.
+    is_ragged, input_data = x_inputs
+    embed_indices = (
+        tf.ragged.constant(input_data, dtype=tf.int32)
+        if is_ragged
+        else tf.convert_to_tensor(input_data, dtype_hint=tf.int32)
+    )
+
+    # The following are invalid test combinations and, hence, are skipped.
+    batch_size = embed_indices.shape[0]
+    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
     if (
-        num_microbatches is not None and batch_size % num_microbatches != 0
-    ) or (
-        model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
+        (num_microbatches is not None and batch_size % num_microbatches != 0)
+        or (model_name == 'weighted_bow1' and is_ragged)
+        or (
+            # Current clipping ops do not have corresponding TPU kernels.
+            using_tpu
+            and is_ragged
+        )
     ):
       return
-    default_registry = layer_registry.make_default_layer_registry()
-    model_generator = get_embedding_model_generators()[model_name]
-    computed_norms, true_norms = (
-        common_test_utils.get_computed_and_true_norms(
-            model_generator=model_generator,
-            layer_generator=None,
-            input_dims=x_batch.shape[1:],
-            output_dim=output_dim,
-            per_example_loss_fn=per_example_loss_fn,
-            num_microbatches=num_microbatches,
-            is_eager=is_eager,
-            x_batch=x_batch,
-            registry=default_registry,
-            partial=partial,
-        )
-    )
-    self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
+
+    # Load shared assets to all devices.
+    with self.strategy.scope():
+      model = common_test_utils.get_model_from_generator(
+          model_generator=get_embedding_model_generators()[model_name],
+          layer_generator=None,
+          input_dims=embed_indices.shape[1:],
+          output_dim=output_dim,
+          is_eager=is_eager,
+      )
+
+    # Define the main testing ops. These may be later compiled to a Graph op.
+    def test_op(x_batch):
+      return common_test_utils.get_computed_and_true_norms_from_model(
+          model=model,
+          per_example_loss_fn=per_example_loss_fn,
+          num_microbatches=num_microbatches,
+          x_batch=x_batch,
+          registry=get_embedding_layer_registries()[layer_registry_name],
+          partial=partial,
+      )
+
+    # TPUs can only run `tf.function`-decorated functions.
+    if using_tpu:
+      test_op = tf.function(test_op, autograph=False)
+
+    # Set up the device ops and run the test.
+    computed_norms, true_norms = self.strategy.run(
+        test_op, args=(embed_indices,)
+    )
+    # TPUs return replica contexts, which must be unwrapped.
+    if using_tpu:
+      common_test_utils.assert_replica_values_are_close(self, computed_norms)
+      common_test_utils.assert_replica_values_are_close(self, true_norms)
+      computed_norms = computed_norms.values[0]
+      true_norms = true_norms.values[0]
+    expected_size = num_microbatches or batch_size
+    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
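Regarding the `using_tpu and is_ragged` skip above: ragged inputs are excluded because the current clipping ops lack TPU kernels for them. If ragged data ever needs to go through this path on a TPU, one conceivable workaround (not part of this commit, and it changes the gradient-norm semantics whenever the padding id maps to a real embedding row) is to densify with padding first:

    import tensorflow as tf

    ragged = tf.ragged.constant([[0], [1], [], [0, 0]], dtype=tf.int32)
    dense_ids = ragged.to_tensor(default_value=0)  # pad rows with id 0
    print(dense_ids.numpy())  # [[0 0] [1 0] [0 0] [0 0]]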
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py (new file)
@@ -0,0 +1,29 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding_test
+
+
+class GradNormTpuTest(embedding_test.GradNormTest):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = common_test_utils.create_tpu_strategy()
+    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+
+
+if __name__ == '__main__':
+  tf.test.main()