Support gradient norm computation with respect to a subset of variables.

PiperOrigin-RevId: 519245638
This commit is contained in:
A. Unique TensorFlower 2023-03-24 14:57:14 -07:00
parent d5d60e2eac
commit 7796369d8b
3 changed files with 61 additions and 11 deletions

View file

@ -35,8 +35,10 @@ py_library(
py_test( py_test(
name = "clip_grads_test", name = "clip_grads_test",
size = "large",
srcs = ["clip_grads_test.py"], srcs = ["clip_grads_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 8,
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":clip_grads", ":clip_grads",

View file

@ -54,7 +54,11 @@ def get_registry_generator_fn(
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, kwargs, tape, num_microbatches layer_instance, args, kwargs, tape, num_microbatches
) )
return layer_outputs, (layer_vars, layer_sqr_norm_fn) return layer_outputs, (
layer_vars,
layer_sqr_norm_fn,
layer_instance.trainable_weights,
)
else: else:
# Non-trainable layer. # Non-trainable layer.
return layer_instance(*args, **kwargs), None return layer_instance(*args, **kwargs), None
@ -69,6 +73,7 @@ def compute_gradient_norms(
layer_registry: lr.LayerRegistry, layer_registry: lr.LayerRegistry,
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None, per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
num_microbatches: Optional[lr.BatchSize] = None, num_microbatches: Optional[lr.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None,
): ):
"""Computes the per-example loss gradient norms for given data. """Computes the per-example loss gradient norms for given data.
@ -96,6 +101,10 @@ def compute_gradient_norms(
of num_microbatches). When there are microbatches, we always assume the of num_microbatches). When there are microbatches, we always assume the
loss is the mean over a microbatch. And the gradient norm is computed for loss is the mean over a microbatch. And the gradient norm is computed for
each microbatch. each microbatch.
trainable_vars: The list of variables included in computing the gradient
norm. When a layer has multiple variables, we include all the variables if
any of the variables is in the list. If `trainable_vars` is None, all the
variables are included.
Returns: Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
@ -126,8 +135,19 @@ def compute_gradient_norms(
# Unwrap the generator outputs so that the next loop avoids duplicating # Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops. # backprop ops.
filtered_outputs = [t for t in generator_outputs_list if t is not None] filtered_outputs = [t for t in generator_outputs_list if t is not None]
vars_list = [a for (a, b) in filtered_outputs] vars_list = []
sqr_norm_fns_list = [b for (a, b) in filtered_outputs] sqr_norm_fns_list = []
if trainable_vars is not None:
# Create a set using `ref()` for fast set membership check. tf.Variable
# itself is not hashable.
trainable_vars = set([v.ref() for v in trainable_vars])
for v, f, weights_list in filtered_outputs:
if trainable_vars is None or any(
w.ref() in trainable_vars for w in weights_list
):
# Include only those variables in trainable_vars.
vars_list.append(v)
sqr_norm_fns_list.append(f)
# Second loop evaluates the squared L2 norm functions and appends the results. # Second loop evaluates the squared L2 norm functions and appends the results.
grads_list = tape.gradient( grads_list = tape.gradient(
summed_loss, summed_loss,
@ -218,6 +238,7 @@ def compute_clipped_gradients_and_outputs(
y_batch, y_batch,
layer_registry, layer_registry,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
trainable_vars=input_model.trainable_variables,
) )
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
with tf.GradientTape() as tape: with tf.GradientTape() as tape:

View file

@ -84,6 +84,7 @@ def compute_true_gradient_norms(
y_batch: tf.Tensor, y_batch: tf.Tensor,
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int], num_microbatches: Optional[int],
trainable_vars: Optional[tf.Variable] = None,
) -> layer_registry.OutputTensor: ) -> layer_registry.OutputTensor:
"""Computes the real gradient norms for an input `(model, x, y)`.""" """Computes the real gradient norms for an input `(model, x, y)`."""
if per_example_loss_fn is None: if per_example_loss_fn is None:
@ -104,7 +105,8 @@ def compute_true_gradient_norms(
if isinstance(loss, tf.RaggedTensor): if isinstance(loss, tf.RaggedTensor):
loss = loss.to_tensor() loss = loss.to_tensor()
sqr_norms = [] sqr_norms = []
for var in input_model.trainable_variables: trainable_vars = trainable_vars or input_model.trainable_variables
for var in trainable_vars:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False) jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, len(jacobian.shape)) reduction_axes = tf.range(1, len(jacobian.shape))
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes) sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
@ -124,6 +126,7 @@ def get_computed_and_true_norms(
x_input: tf.Tensor, x_input: tf.Tensor,
rng_seed: int = 777, rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None, registry: layer_registry.LayerRegistry = None,
partial: bool = False,
) -> Tuple[tf.Tensor, tf.Tensor]: ) -> Tuple[tf.Tensor, tf.Tensor]:
"""Obtains the true and computed gradient norms for a model and batch input. """Obtains the true and computed gradient norms for a model and batch input.
@ -147,6 +150,8 @@ def get_computed_and_true_norms(
x_input: `tf.Tensor` inputs to be tested. x_input: `tf.Tensor` inputs to be tested.
rng_seed: An `int` used to initialize model weights. rng_seed: An `int` used to initialize model weights.
registry: A `layer_registry.LayerRegistry` instance. registry: A `layer_registry.LayerRegistry` instance.
partial: Whether to compute the gradient norm with respect to a partial set
of variables. If True, only consider the variables in the first layer
that has trainable variables.
Returns: Returns:
A `tuple` `(computed_norm, true_norms)`. The first element contains the A `tuple` `(computed_norm, true_norms)`. The first element contains the
@ -163,6 +168,13 @@ def get_computed_and_true_norms(
), ),
run_eagerly=is_eager, run_eagerly=is_eager,
) )
trainable_vars = None
if partial:
# Gets the first layer with variables.
for l in model.layers:
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_input) y_pred = model(x_input)
y_batch = tf.ones_like(y_pred) y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed) tf.keras.utils.set_random_seed(rng_seed)
@ -173,10 +185,16 @@ def get_computed_and_true_norms(
layer_registry=registry, layer_registry=registry,
per_example_loss_fn=per_example_loss_fn, per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
) )
tf.keras.utils.set_random_seed(rng_seed) tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms( true_norms = compute_true_gradient_norms(
model, x_input, y_batch, per_example_loss_fn, num_microbatches model,
x_input,
y_batch,
per_example_loss_fn,
num_microbatches,
trainable_vars=trainable_vars,
) )
return (computed_norms, true_norms) return (computed_norms, true_norms)
@ -360,10 +378,11 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
model_name=list(get_dense_model_generators().keys()), model_name=list(get_dense_model_generators().keys()),
layer_name=list(get_dense_layer_generators().keys()), layer_name=list(get_dense_layer_generators().keys()),
input_dim=[4], input_dim=[4],
output_dim=[1, 2], output_dim=[2],
per_example_loss_fn=[None, test_loss_fn], per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 1, 2], num_microbatches=[None, 1, 2],
is_eager=[True, False], is_eager=[True, False],
partial=[True, False],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, self,
@ -374,6 +393,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn, per_example_loss_fn,
num_microbatches, num_microbatches,
is_eager, is_eager,
partial,
): ):
model_generator = get_dense_model_generators()[model_name] model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name] layer_generator = get_dense_layer_generators()[layer_name]
@ -399,6 +419,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
is_eager, is_eager,
x_input, x_input,
registry=default_registry, registry=default_registry,
partial=partial,
) )
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@ -436,8 +457,9 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
model_name=list(get_embedding_model_generators().keys()), model_name=list(get_embedding_model_generators().keys()),
output_dim=[2], output_dim=[2],
per_example_loss_fn=[None, test_loss_fn], per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 1, 2], num_microbatches=[None, 2],
is_eager=[True], is_eager=[True, False],
partial=[True, False],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, self,
@ -447,6 +469,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn, per_example_loss_fn,
num_microbatches, num_microbatches,
is_eager, is_eager,
partial,
): ):
if ( if (
num_microbatches is not None num_microbatches is not None
@ -470,6 +493,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
is_eager=is_eager, is_eager=is_eager,
x_input=x_batch, x_input=x_batch,
registry=default_registry, registry=default_registry,
partial=partial,
) )
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@ -477,11 +501,12 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product( @parameterized.product(
input_dim=[1, 2], input_dim=[3],
output_dim=[1, 2], output_dim=[2],
per_example_loss_fn=[None, test_loss_fn], per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 1, 2], num_microbatches=[None, 2],
is_eager=[True, False], is_eager=[True, False],
partial=[True, False],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, self,
@ -490,6 +515,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn, per_example_loss_fn,
num_microbatches, num_microbatches,
is_eager, is_eager,
partial,
): ):
registry = layer_registry.make_default_layer_registry() registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation) registry.insert(DoubleDense, double_dense_layer_computation)
@ -510,6 +536,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
is_eager=is_eager, is_eager=is_eager,
x_input=x_batch, x_input=x_batch,
registry=registry, registry=registry,
partial=partial,
) )
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)