Support gradient norm computation with respect to a subset of variables.
PiperOrigin-RevId: 519245638
This commit is contained in:
parent
d5d60e2eac
commit
7796369d8b
3 changed files with 61 additions and 11 deletions
|
@ -35,8 +35,10 @@ py_library(
|
|||
|
||||
py_test(
|
||||
name = "clip_grads_test",
|
||||
size = "large",
|
||||
srcs = ["clip_grads_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 8,
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":clip_grads",
|
||||
|
|
|
@ -54,7 +54,11 @@ def get_registry_generator_fn(
|
|||
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
||||
layer_instance, args, kwargs, tape, num_microbatches
|
||||
)
|
||||
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
|
||||
return layer_outputs, (
|
||||
layer_vars,
|
||||
layer_sqr_norm_fn,
|
||||
layer_instance.trainable_weights,
|
||||
)
|
||||
else:
|
||||
# Non-trainable layer.
|
||||
return layer_instance(*args, **kwargs), None
|
||||
|
@ -69,6 +73,7 @@ def compute_gradient_norms(
|
|||
layer_registry: lr.LayerRegistry,
|
||||
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
|
||||
num_microbatches: Optional[lr.BatchSize] = None,
|
||||
trainable_vars: Optional[List[tf.Variable]] = None,
|
||||
):
|
||||
"""Computes the per-example loss gradient norms for given data.
|
||||
|
||||
|
@ -96,6 +101,10 @@ def compute_gradient_norms(
|
|||
of num_microbatches). When there is microbatches, we always assume the
|
||||
loss is the mean over a microbatch. And the gradient norm is computed for
|
||||
each microbatch.
|
||||
trainable_vars: The list of variables included in computing the gradient
|
||||
norm. When a layer has multiple variables, we include all the variables if
|
||||
any of the variables is in the list. If `trainable_vars` is None, all the
|
||||
variables are included.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||
|
@ -126,8 +135,19 @@ def compute_gradient_norms(
|
|||
# Unwrap the generator outputs so that the next loop avoids duplicating
|
||||
# backprop ops.
|
||||
filtered_outputs = [t for t in generator_outputs_list if t is not None]
|
||||
vars_list = [a for (a, b) in filtered_outputs]
|
||||
sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
|
||||
vars_list = []
|
||||
sqr_norm_fns_list = []
|
||||
if trainable_vars is not None:
|
||||
# Create a set using `ref()` for fast set membership check. tf.Variable
|
||||
# itself is not hashable.
|
||||
trainable_vars = set([v.ref() for v in trainable_vars])
|
||||
for v, f, weights_list in filtered_outputs:
|
||||
if trainable_vars is None or any(
|
||||
w.ref() in trainable_vars for w in weights_list
|
||||
):
|
||||
# Include only those variables in trainable_vars.
|
||||
vars_list.append(v)
|
||||
sqr_norm_fns_list.append(f)
|
||||
# Second loop evaluates the squared L2 norm functions and appends the results.
|
||||
grads_list = tape.gradient(
|
||||
summed_loss,
|
||||
|
@ -218,6 +238,7 @@ def compute_clipped_gradients_and_outputs(
|
|||
y_batch,
|
||||
layer_registry,
|
||||
num_microbatches=num_microbatches,
|
||||
trainable_vars=input_model.trainable_variables,
|
||||
)
|
||||
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
||||
with tf.GradientTape() as tape:
|
||||
|
|
|
@ -84,6 +84,7 @@ def compute_true_gradient_norms(
|
|||
y_batch: tf.Tensor,
|
||||
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||
num_microbatches: Optional[int],
|
||||
trainable_vars: Optional[tf.Variable] = None,
|
||||
) -> layer_registry.OutputTensor:
|
||||
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
||||
if per_example_loss_fn is None:
|
||||
|
@ -104,7 +105,8 @@ def compute_true_gradient_norms(
|
|||
if isinstance(loss, tf.RaggedTensor):
|
||||
loss = loss.to_tensor()
|
||||
sqr_norms = []
|
||||
for var in input_model.trainable_variables:
|
||||
trainable_vars = trainable_vars or input_model.trainable_variables
|
||||
for var in trainable_vars:
|
||||
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
|
||||
reduction_axes = tf.range(1, len(jacobian.shape))
|
||||
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
|
||||
|
@ -124,6 +126,7 @@ def get_computed_and_true_norms(
|
|||
x_input: tf.Tensor,
|
||||
rng_seed: int = 777,
|
||||
registry: layer_registry.LayerRegistry = None,
|
||||
partial: bool = False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""Obtains the true and computed gradient norms for a model and batch input.
|
||||
|
||||
|
@ -147,6 +150,8 @@ def get_computed_and_true_norms(
|
|||
x_input: `tf.Tensor` inputs to be tested.
|
||||
rng_seed: An `int` used to initialize model weights.
|
||||
registry: A `layer_registry.LayerRegistry` instance.
|
||||
partial: Whether to compute the gradient norm with respect to a partial set
|
||||
of varibles. If True, only consider the variables in the first layer.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(computed_norm, true_norms)`. The first element contains the
|
||||
|
@ -163,6 +168,13 @@ def get_computed_and_true_norms(
|
|||
),
|
||||
run_eagerly=is_eager,
|
||||
)
|
||||
trainable_vars = None
|
||||
if partial:
|
||||
# Gets the first layer with variables.
|
||||
for l in model.layers:
|
||||
trainable_vars = l.trainable_variables
|
||||
if trainable_vars:
|
||||
break
|
||||
y_pred = model(x_input)
|
||||
y_batch = tf.ones_like(y_pred)
|
||||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
|
@ -173,10 +185,16 @@ def get_computed_and_true_norms(
|
|||
layer_registry=registry,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
trainable_vars=trainable_vars,
|
||||
)
|
||||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
true_norms = compute_true_gradient_norms(
|
||||
model, x_input, y_batch, per_example_loss_fn, num_microbatches
|
||||
model,
|
||||
x_input,
|
||||
y_batch,
|
||||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
trainable_vars=trainable_vars,
|
||||
)
|
||||
return (computed_norms, true_norms)
|
||||
|
||||
|
@ -360,10 +378,11 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
model_name=list(get_dense_model_generators().keys()),
|
||||
layer_name=list(get_dense_layer_generators().keys()),
|
||||
input_dim=[4],
|
||||
output_dim=[1, 2],
|
||||
output_dim=[2],
|
||||
per_example_loss_fn=[None, test_loss_fn],
|
||||
num_microbatches=[None, 1, 2],
|
||||
is_eager=[True, False],
|
||||
partial=[True, False],
|
||||
)
|
||||
def test_gradient_norms_on_various_models(
|
||||
self,
|
||||
|
@ -374,6 +393,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
partial,
|
||||
):
|
||||
model_generator = get_dense_model_generators()[model_name]
|
||||
layer_generator = get_dense_layer_generators()[layer_name]
|
||||
|
@ -399,6 +419,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
is_eager,
|
||||
x_input,
|
||||
registry=default_registry,
|
||||
partial=partial,
|
||||
)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
@ -436,8 +457,9 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
model_name=list(get_embedding_model_generators().keys()),
|
||||
output_dim=[2],
|
||||
per_example_loss_fn=[None, test_loss_fn],
|
||||
num_microbatches=[None, 1, 2],
|
||||
is_eager=[True],
|
||||
num_microbatches=[None, 2],
|
||||
is_eager=[True, False],
|
||||
partial=[True, False],
|
||||
)
|
||||
def test_gradient_norms_on_various_models(
|
||||
self,
|
||||
|
@ -447,6 +469,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
partial,
|
||||
):
|
||||
if (
|
||||
num_microbatches is not None
|
||||
|
@ -470,6 +493,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
is_eager=is_eager,
|
||||
x_input=x_batch,
|
||||
registry=default_registry,
|
||||
partial=partial,
|
||||
)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
@ -477,11 +501,12 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
input_dim=[1, 2],
|
||||
output_dim=[1, 2],
|
||||
input_dim=[3],
|
||||
output_dim=[2],
|
||||
per_example_loss_fn=[None, test_loss_fn],
|
||||
num_microbatches=[None, 1, 2],
|
||||
num_microbatches=[None, 2],
|
||||
is_eager=[True, False],
|
||||
partial=[True, False],
|
||||
)
|
||||
def test_gradient_norms_on_various_models(
|
||||
self,
|
||||
|
@ -490,6 +515,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
partial,
|
||||
):
|
||||
registry = layer_registry.make_default_layer_registry()
|
||||
registry.insert(DoubleDense, double_dense_layer_computation)
|
||||
|
@ -510,6 +536,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
is_eager=is_eager,
|
||||
x_input=x_batch,
|
||||
registry=registry,
|
||||
partial=partial,
|
||||
)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
|
Loading…
Reference in a new issue