Implement fast gradient clipping for loss functions that use inputs that are fed into shared weights.

PiperOrigin-RevId: 625395017
This commit is contained in:
William Kong 2024-04-16 11:18:49 -07:00 committed by A. Unique TensorFlower
parent 0582cfdd1a
commit 44dfac3770
3 changed files with 134 additions and 13 deletions

View file

@ -87,6 +87,7 @@ py_test(
name = "clip_grads_test", name = "clip_grads_test",
srcs = ["clip_grads_test.py"], srcs = ["clip_grads_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 8,
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":clip_grads", ":clip_grads",

View file

@ -21,6 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function). `compute_gradient_norms()` function).
""" """
import collections
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import Optional
@ -56,6 +57,7 @@ def get_registry_generator_fn(
layer_instance, args, kwargs, tape, num_microbatches layer_instance, args, kwargs, tape, num_microbatches
) )
return layer_outputs, ( return layer_outputs, (
str(id(layer_instance)),
layer_vars, layer_vars,
layer_sqr_norm_fn, layer_sqr_norm_fn,
layer_instance.trainable_weights, layer_instance.trainable_weights,
@ -156,32 +158,47 @@ def compute_gradient_norms(
# Unwrap the generator outputs so that the next loop avoids duplicating # Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops. # backprop ops.
filtered_outputs = [t for t in generator_outputs_list if t is not None] filtered_outputs = [t for t in generator_outputs_list if t is not None]
vars_list = []
sqr_norm_fns_list = []
if trainable_vars is not None: if trainable_vars is not None:
# Create a set using `ref()` for fast set membership check. tf.Variable # Create a set using `ref()` for fast set membership check. tf.Variable
# itself is not hashable. # itself is not hashable.
trainable_vars = set([v.ref() for v in trainable_vars]) trainable_vars = set([v.ref() for v in trainable_vars])
for v, f, weights_list in filtered_outputs: layer_vars = collections.defaultdict(list)
layer_sqr_norm_fns = collections.defaultdict(list)
# The case of shared weights:
# If a layer is called k times, it will appear k times in filtered_outputs,
# with the same id, but potentially with different v and f. The code below
# groups filtered_outputs by layer_id, so we can correctly compute gradient
# norms. The gradient norm of a layer that occurs k times is computed as
# $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
# occurrence. This is an over-estimate of the actual norm. For more details,
# see the explanation in go/dp-sgd-shared-weights.
for layer_id, v, f, weights_list in filtered_outputs:
if trainable_vars is None or any( if trainable_vars is None or any(
w.ref() in trainable_vars for w in weights_list w.ref() in trainable_vars for w in weights_list
): ):
# Include only those variables in trainable_vars. layer_vars[layer_id].append(v)
vars_list.append(v) layer_sqr_norm_fns[layer_id].append(f)
sqr_norm_fns_list.append(f)
# Second loop evaluates the squared L2 norm functions and appends the results. # Second loop evaluates the squared L2 norm functions and appends the results.
grads_list = tape.gradient( layer_grad_vars = tape.gradient(
summed_loss, summed_loss,
vars_list, layer_vars,
unconnected_gradients=tf.UnconnectedGradients.ZERO, unconnected_gradients=tf.UnconnectedGradients.ZERO,
) )
if not grads_list: if not layer_grad_vars:
raise ValueError('The gradient list cannot be empty.') raise ValueError('The gradient list cannot be empty.')
if len(grads_list) != len(sqr_norm_fns_list):
raise ValueError('There must be as many norms as gradients.')
sqr_norm_list = [] sqr_norm_list = []
for grads, f in zip(grads_list, sqr_norm_fns_list): for layer_id in layer_sqr_norm_fns.keys():
sqr_norm_list.append(f(grads)) fns = layer_sqr_norm_fns[layer_id]
grads = layer_grad_vars[layer_id]
# Number of duplicates for this layer in `filtered_outputs`.
num_passes = len(fns)
if len(fns) != len(grads):
raise ValueError(
'There must be as many gradients as squared norm functions.'
)
# See go/dp-sgd-shared-weights for more details.
for fn, grad in zip(fns, grads):
sqr_norm_list.append(num_passes * fn(grad))
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))

View file

@ -197,5 +197,108 @@ class ComputeClippedGradsAndOutputsTest(
self.assertAlmostEqual(computed_norm, true_norm) self.assertAlmostEqual(computed_norm, true_norm)
class SharedLayerTest(tf.test.TestCase, parameterized.TestCase):
def _make_shared_model(self, num_inputs, input_dim):
base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)])
inputs = []
outputs = []
for _ in range(num_inputs):
input_tensor = tf.keras.Input(shape=[input_dim])
inputs.append(input_tensor)
output_tensor = base_model(input_tensor)
outputs.append(output_tensor)
return tf.keras.Model(inputs=inputs, outputs=tf.add_n(outputs))
def _get_computed_and_true_norms(self, model, x_batch, y_batch, is_eager):
model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction='none'),
run_eagerly=is_eager,
)
computed_norms = clip_grads.compute_gradient_norms(
model, layer_registry.make_default_layer_registry(), x_batch, y_batch
)
with tf.GradientTape() as tape:
y_pred = model(x_batch)
loss_value = model.loss(y_pred, y_batch)
true_grads = tape.jacobian(loss_value, model.trainable_variables)
true_norms = tf.sqrt(
tf.add_n([tf.reduce_sum(tf.square(g), axis=[1, 2]) for g in true_grads])
)
return computed_norms, true_norms
@parameterized.product(
num_inputs=[1, 2, 10],
batch_size=[1, 2],
input_dim=[1, 3],
is_eager=[True, False],
)
def test_gradient_norms_on_multiple_inputs_are_upper_bounded(
self, num_inputs, batch_size, input_dim, is_eager
):
model = self._make_shared_model(num_inputs, input_dim)
model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction='none'),
run_eagerly=is_eager,
)
x_batch = [
float(k + 1) * tf.ones([batch_size, input_dim], dtype=tf.float64)
for k in range(num_inputs)
]
y_batch = tf.reshape(
1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
)
computed_norms, true_norms = self._get_computed_and_true_norms(
model, x_batch, y_batch, is_eager
)
self.assertAllLessEqual(true_norms - computed_norms, 1e-3)
@parameterized.product(
num_repeats=[1, 2, 10],
batch_size=[1, 2],
input_dim=[1, 3],
is_eager=[True, False],
)
def test_gradient_norms_on_single_repeated_input_are_upper_bounded(
self, num_repeats, batch_size, input_dim, is_eager
):
base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)])
inputs = tf.keras.layers.Input([input_dim])
outputs = tf.add_n([base_model(inputs) for _ in range(num_repeats)])
model = tf.keras.Model(inputs=inputs, outputs=outputs)
x_batch = tf.ones([batch_size, input_dim], dtype=tf.float64)
y_batch = tf.reshape(
1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
)
computed_norms, true_norms = self._get_computed_and_true_norms(
model, x_batch, y_batch, is_eager
)
self.assertAllLessEqual(true_norms - computed_norms, 1e-3)
@parameterized.product(
batch_size=[1, 2],
input_dim=[1, 3],
is_eager=[True, False],
)
def test_gradient_norms_on_input_slices_are_upper_bounded(
self, batch_size, input_dim, is_eager
):
base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)])
inputs = tf.keras.layers.Input([input_dim, 2])
outputs = base_model(inputs[:, :, 0]) + base_model(inputs[:, :, 1])
model = tf.keras.Model(inputs=inputs, outputs=outputs)
x_batch = tf.reshape(
tf.range(batch_size * input_dim * 2, dtype=tf.float64),
[batch_size, input_dim, -1],
)
y_batch = tf.reshape(
1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
)
computed_norms, true_norms = self._get_computed_and_true_norms(
model, x_batch, y_batch, is_eager
)
self.assertAllLessEqual(true_norms - computed_norms, 1e-3)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()