Support gradient norm computation with respect to a subset of variables.
PiperOrigin-RevId: 519245638
This commit is contained in:
parent
d5d60e2eac
commit
7796369d8b
3 changed files with 61 additions and 11 deletions
|
@ -35,8 +35,10 @@ py_library(
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "clip_grads_test",
|
name = "clip_grads_test",
|
||||||
|
size = "large",
|
||||||
srcs = ["clip_grads_test.py"],
|
srcs = ["clip_grads_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
|
shard_count = 8,
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
":clip_grads",
|
":clip_grads",
|
||||||
|
|
|
@ -54,7 +54,11 @@ def get_registry_generator_fn(
|
||||||
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
||||||
layer_instance, args, kwargs, tape, num_microbatches
|
layer_instance, args, kwargs, tape, num_microbatches
|
||||||
)
|
)
|
||||||
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
|
return layer_outputs, (
|
||||||
|
layer_vars,
|
||||||
|
layer_sqr_norm_fn,
|
||||||
|
layer_instance.trainable_weights,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Non-trainable layer.
|
# Non-trainable layer.
|
||||||
return layer_instance(*args, **kwargs), None
|
return layer_instance(*args, **kwargs), None
|
||||||
|
@ -69,6 +73,7 @@ def compute_gradient_norms(
|
||||||
layer_registry: lr.LayerRegistry,
|
layer_registry: lr.LayerRegistry,
|
||||||
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
|
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
|
||||||
num_microbatches: Optional[lr.BatchSize] = None,
|
num_microbatches: Optional[lr.BatchSize] = None,
|
||||||
|
trainable_vars: Optional[List[tf.Variable]] = None,
|
||||||
):
|
):
|
||||||
"""Computes the per-example loss gradient norms for given data.
|
"""Computes the per-example loss gradient norms for given data.
|
||||||
|
|
||||||
|
@ -96,6 +101,10 @@ def compute_gradient_norms(
|
||||||
of num_microbatches). When there are microbatches, we always assume the
|
of num_microbatches). When there are microbatches, we always assume the
|
||||||
loss is the mean over a microbatch. And the gradient norm is computed for
|
loss is the mean over a microbatch. And the gradient norm is computed for
|
||||||
each microbatch.
|
each microbatch.
|
||||||
|
trainable_vars: The list of variables included in computing the gradient
|
||||||
|
norm. When a layer has multiple variables, we include all the variables if
|
||||||
|
any of the variables is in the list. If `trainable_vars` is None, all the
|
||||||
|
variables are included.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||||
|
@ -126,8 +135,19 @@ def compute_gradient_norms(
|
||||||
# Unwrap the generator outputs so that the next loop avoids duplicating
|
# Unwrap the generator outputs so that the next loop avoids duplicating
|
||||||
# backprop ops.
|
# backprop ops.
|
||||||
filtered_outputs = [t for t in generator_outputs_list if t is not None]
|
filtered_outputs = [t for t in generator_outputs_list if t is not None]
|
||||||
vars_list = [a for (a, b) in filtered_outputs]
|
vars_list = []
|
||||||
sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
|
sqr_norm_fns_list = []
|
||||||
|
if trainable_vars is not None:
|
||||||
|
# Create a set using `ref()` for fast set membership check. tf.Variable
|
||||||
|
# itself is not hashable.
|
||||||
|
trainable_vars = set([v.ref() for v in trainable_vars])
|
||||||
|
for v, f, weights_list in filtered_outputs:
|
||||||
|
if trainable_vars is None or any(
|
||||||
|
w.ref() in trainable_vars for w in weights_list
|
||||||
|
):
|
||||||
|
# Include only those variables in trainable_vars.
|
||||||
|
vars_list.append(v)
|
||||||
|
sqr_norm_fns_list.append(f)
|
||||||
# Second loop evaluates the squared L2 norm functions and appends the results.
|
# Second loop evaluates the squared L2 norm functions and appends the results.
|
||||||
grads_list = tape.gradient(
|
grads_list = tape.gradient(
|
||||||
summed_loss,
|
summed_loss,
|
||||||
|
@ -218,6 +238,7 @@ def compute_clipped_gradients_and_outputs(
|
||||||
y_batch,
|
y_batch,
|
||||||
layer_registry,
|
layer_registry,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
|
trainable_vars=input_model.trainable_variables,
|
||||||
)
|
)
|
||||||
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
|
|
|
@ -84,6 +84,7 @@ def compute_true_gradient_norms(
|
||||||
y_batch: tf.Tensor,
|
y_batch: tf.Tensor,
|
||||||
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||||
num_microbatches: Optional[int],
|
num_microbatches: Optional[int],
|
||||||
|
trainable_vars: Optional[tf.Variable] = None,
|
||||||
) -> layer_registry.OutputTensor:
|
) -> layer_registry.OutputTensor:
|
||||||
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
||||||
if per_example_loss_fn is None:
|
if per_example_loss_fn is None:
|
||||||
|
@ -104,7 +105,8 @@ def compute_true_gradient_norms(
|
||||||
if isinstance(loss, tf.RaggedTensor):
|
if isinstance(loss, tf.RaggedTensor):
|
||||||
loss = loss.to_tensor()
|
loss = loss.to_tensor()
|
||||||
sqr_norms = []
|
sqr_norms = []
|
||||||
for var in input_model.trainable_variables:
|
trainable_vars = trainable_vars or input_model.trainable_variables
|
||||||
|
for var in trainable_vars:
|
||||||
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
|
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
|
||||||
reduction_axes = tf.range(1, len(jacobian.shape))
|
reduction_axes = tf.range(1, len(jacobian.shape))
|
||||||
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
|
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
|
||||||
|
@ -124,6 +126,7 @@ def get_computed_and_true_norms(
|
||||||
x_input: tf.Tensor,
|
x_input: tf.Tensor,
|
||||||
rng_seed: int = 777,
|
rng_seed: int = 777,
|
||||||
registry: layer_registry.LayerRegistry = None,
|
registry: layer_registry.LayerRegistry = None,
|
||||||
|
partial: bool = False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor]:
|
) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
"""Obtains the true and computed gradient norms for a model and batch input.
|
"""Obtains the true and computed gradient norms for a model and batch input.
|
||||||
|
|
||||||
|
@ -147,6 +150,8 @@ def get_computed_and_true_norms(
|
||||||
x_input: `tf.Tensor` inputs to be tested.
|
x_input: `tf.Tensor` inputs to be tested.
|
||||||
rng_seed: An `int` used to initialize model weights.
|
rng_seed: An `int` used to initialize model weights.
|
||||||
registry: A `layer_registry.LayerRegistry` instance.
|
registry: A `layer_registry.LayerRegistry` instance.
|
||||||
|
partial: Whether to compute the gradient norm with respect to a partial set
|
||||||
|
of variables. If True, only consider the variables in the first layer.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tuple` `(computed_norm, true_norms)`. The first element contains the
|
A `tuple` `(computed_norm, true_norms)`. The first element contains the
|
||||||
|
@ -163,6 +168,13 @@ def get_computed_and_true_norms(
|
||||||
),
|
),
|
||||||
run_eagerly=is_eager,
|
run_eagerly=is_eager,
|
||||||
)
|
)
|
||||||
|
trainable_vars = None
|
||||||
|
if partial:
|
||||||
|
# Gets the first layer with variables.
|
||||||
|
for l in model.layers:
|
||||||
|
trainable_vars = l.trainable_variables
|
||||||
|
if trainable_vars:
|
||||||
|
break
|
||||||
y_pred = model(x_input)
|
y_pred = model(x_input)
|
||||||
y_batch = tf.ones_like(y_pred)
|
y_batch = tf.ones_like(y_pred)
|
||||||
tf.keras.utils.set_random_seed(rng_seed)
|
tf.keras.utils.set_random_seed(rng_seed)
|
||||||
|
@ -173,10 +185,16 @@ def get_computed_and_true_norms(
|
||||||
layer_registry=registry,
|
layer_registry=registry,
|
||||||
per_example_loss_fn=per_example_loss_fn,
|
per_example_loss_fn=per_example_loss_fn,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
|
trainable_vars=trainable_vars,
|
||||||
)
|
)
|
||||||
tf.keras.utils.set_random_seed(rng_seed)
|
tf.keras.utils.set_random_seed(rng_seed)
|
||||||
true_norms = compute_true_gradient_norms(
|
true_norms = compute_true_gradient_norms(
|
||||||
model, x_input, y_batch, per_example_loss_fn, num_microbatches
|
model,
|
||||||
|
x_input,
|
||||||
|
y_batch,
|
||||||
|
per_example_loss_fn,
|
||||||
|
num_microbatches,
|
||||||
|
trainable_vars=trainable_vars,
|
||||||
)
|
)
|
||||||
return (computed_norms, true_norms)
|
return (computed_norms, true_norms)
|
||||||
|
|
||||||
|
@ -360,10 +378,11 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
model_name=list(get_dense_model_generators().keys()),
|
model_name=list(get_dense_model_generators().keys()),
|
||||||
layer_name=list(get_dense_layer_generators().keys()),
|
layer_name=list(get_dense_layer_generators().keys()),
|
||||||
input_dim=[4],
|
input_dim=[4],
|
||||||
output_dim=[1, 2],
|
output_dim=[2],
|
||||||
per_example_loss_fn=[None, test_loss_fn],
|
per_example_loss_fn=[None, test_loss_fn],
|
||||||
num_microbatches=[None, 1, 2],
|
num_microbatches=[None, 1, 2],
|
||||||
is_eager=[True, False],
|
is_eager=[True, False],
|
||||||
|
partial=[True, False],
|
||||||
)
|
)
|
||||||
def test_gradient_norms_on_various_models(
|
def test_gradient_norms_on_various_models(
|
||||||
self,
|
self,
|
||||||
|
@ -374,6 +393,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
per_example_loss_fn,
|
per_example_loss_fn,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
is_eager,
|
is_eager,
|
||||||
|
partial,
|
||||||
):
|
):
|
||||||
model_generator = get_dense_model_generators()[model_name]
|
model_generator = get_dense_model_generators()[model_name]
|
||||||
layer_generator = get_dense_layer_generators()[layer_name]
|
layer_generator = get_dense_layer_generators()[layer_name]
|
||||||
|
@ -399,6 +419,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
is_eager,
|
is_eager,
|
||||||
x_input,
|
x_input,
|
||||||
registry=default_registry,
|
registry=default_registry,
|
||||||
|
partial=partial,
|
||||||
)
|
)
|
||||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||||
|
|
||||||
|
@ -436,8 +457,9 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
model_name=list(get_embedding_model_generators().keys()),
|
model_name=list(get_embedding_model_generators().keys()),
|
||||||
output_dim=[2],
|
output_dim=[2],
|
||||||
per_example_loss_fn=[None, test_loss_fn],
|
per_example_loss_fn=[None, test_loss_fn],
|
||||||
num_microbatches=[None, 1, 2],
|
num_microbatches=[None, 2],
|
||||||
is_eager=[True],
|
is_eager=[True, False],
|
||||||
|
partial=[True, False],
|
||||||
)
|
)
|
||||||
def test_gradient_norms_on_various_models(
|
def test_gradient_norms_on_various_models(
|
||||||
self,
|
self,
|
||||||
|
@ -447,6 +469,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
per_example_loss_fn,
|
per_example_loss_fn,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
is_eager,
|
is_eager,
|
||||||
|
partial,
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
num_microbatches is not None
|
num_microbatches is not None
|
||||||
|
@ -470,6 +493,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
is_eager=is_eager,
|
is_eager=is_eager,
|
||||||
x_input=x_batch,
|
x_input=x_batch,
|
||||||
registry=default_registry,
|
registry=default_registry,
|
||||||
|
partial=partial,
|
||||||
)
|
)
|
||||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||||
|
|
||||||
|
@ -477,11 +501,12 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
input_dim=[1, 2],
|
input_dim=[3],
|
||||||
output_dim=[1, 2],
|
output_dim=[2],
|
||||||
per_example_loss_fn=[None, test_loss_fn],
|
per_example_loss_fn=[None, test_loss_fn],
|
||||||
num_microbatches=[None, 1, 2],
|
num_microbatches=[None, 2],
|
||||||
is_eager=[True, False],
|
is_eager=[True, False],
|
||||||
|
partial=[True, False],
|
||||||
)
|
)
|
||||||
def test_gradient_norms_on_various_models(
|
def test_gradient_norms_on_various_models(
|
||||||
self,
|
self,
|
||||||
|
@ -490,6 +515,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
per_example_loss_fn,
|
per_example_loss_fn,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
is_eager,
|
is_eager,
|
||||||
|
partial,
|
||||||
):
|
):
|
||||||
registry = layer_registry.make_default_layer_registry()
|
registry = layer_registry.make_default_layer_registry()
|
||||||
registry.insert(DoubleDense, double_dense_layer_computation)
|
registry.insert(DoubleDense, double_dense_layer_computation)
|
||||||
|
@ -510,6 +536,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
is_eager=is_eager,
|
is_eager=is_eager,
|
||||||
x_input=x_batch,
|
x_input=x_batch,
|
||||||
registry=registry,
|
registry=registry,
|
||||||
|
partial=partial,
|
||||||
)
|
)
|
||||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue