Implement fast gradient clipping for loss functions that use inputs that are fed into shared weights.
PiperOrigin-RevId: 625395017
parent 0582cfdd1a
commit 44dfac3770
3 changed files with 134 additions and 13 deletions
@@ -87,6 +87,7 @@ py_test(
     name = "clip_grads_test",
     srcs = ["clip_grads_test.py"],
     python_version = "PY3",
+    shard_count = 8,
     srcs_version = "PY3",
     deps = [
         ":clip_grads",
@@ -21,6 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """

+import collections
 from collections.abc import Sequence
 from typing import Optional

@@ -56,6 +57,7 @@ def get_registry_generator_fn(
           layer_instance, args, kwargs, tape, num_microbatches
       )
       return layer_outputs, (
+          str(id(layer_instance)),
           layer_vars,
           layer_sqr_norm_fn,
           layer_instance.trainable_weights,
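Note on the new `str(id(layer_instance))` tag: a Keras layer object keeps the same Python id for every call, so the id can serve as a grouping key when one layer (and therefore one set of shared weights) is invoked several times in a forward pass. A minimal, self-contained sketch of that grouping idea follows; it is illustrative only and not the library code (all names below are made up for the example):

import collections

import tensorflow as tf

# Illustrative sketch: tag each call with the calling layer's Python id, then
# group the per-call records by that id, mirroring how the generator outputs
# are later grouped in compute_gradient_norms().
dense = tf.keras.layers.Dense(1, use_bias=False)  # one shared kernel
records = []
for x in (tf.ones([2, 3]), 2.0 * tf.ones([2, 3])):
  _ = dense(x)  # two calls through the same layer instance
  records.append((str(id(dense)), x))

by_layer = collections.defaultdict(list)
for layer_id, payload in records:
  by_layer[layer_id].append(payload)

assert len(by_layer) == 1  # both calls land under a single key
assert len(by_layer[str(id(dense))]) == 2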
@@ -156,32 +158,47 @@ def compute_gradient_norms(
   # Unwrap the generator outputs so that the next loop avoids duplicating
   # backprop ops.
   filtered_outputs = [t for t in generator_outputs_list if t is not None]
-  vars_list = []
-  sqr_norm_fns_list = []
   if trainable_vars is not None:
     # Create a set using `ref()` for fast set membership check. tf.Variable
     # itself is not hashable.
     trainable_vars = set([v.ref() for v in trainable_vars])
-  for v, f, weights_list in filtered_outputs:
+  layer_vars = collections.defaultdict(list)
+  layer_sqr_norm_fns = collections.defaultdict(list)
+  # The case of shared weights:
+  # If a layer is called k times, it will appear k times in filtered_outputs,
+  # with the same id, but potentially with different v and f. The code below
+  # groups filtered_outputs by layer_id, so we can correctly compute gradient
+  # norms. The gradient norm of a layer that occurs k times is computed as
+  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
+  # occurrence. This is an over-estimate of the actual norm. For more details,
+  # see the explanation in go/dp-sgd-shared-weights.
+  for layer_id, v, f, weights_list in filtered_outputs:
     if trainable_vars is None or any(
         w.ref() in trainable_vars for w in weights_list
     ):
-      # Include only those variables in trainable_vars.
-      vars_list.append(v)
-      sqr_norm_fns_list.append(f)
+      layer_vars[layer_id].append(v)
+      layer_sqr_norm_fns[layer_id].append(f)
   # Second loop evaluates the squared L2 norm functions and appends the results.
-  grads_list = tape.gradient(
+  layer_grad_vars = tape.gradient(
       summed_loss,
-      vars_list,
+      layer_vars,
       unconnected_gradients=tf.UnconnectedGradients.ZERO,
   )
-  if not grads_list:
+  if not layer_grad_vars:
     raise ValueError('The gradient list cannot be empty.')
-  if len(grads_list) != len(sqr_norm_fns_list):
-    raise ValueError('There must be as many norms as gradients.')
   sqr_norm_list = []
-  for grads, f in zip(grads_list, sqr_norm_fns_list):
-    sqr_norm_list.append(f(grads))
+  for layer_id in layer_sqr_norm_fns.keys():
+    fns = layer_sqr_norm_fns[layer_id]
+    grads = layer_grad_vars[layer_id]
+    # Number of duplicates for this layer in `filtered_outputs`.
+    num_passes = len(fns)
+    if len(fns) != len(grads):
+      raise ValueError(
+          'There must be as many gradients as squared norm functions.'
+      )
+    # See go/dp-sgd-shared-weights for more details.
+    for fn, grad in zip(fns, grads):
+      sqr_norm_list.append(num_passes * fn(grad))
   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
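A short justification of the `num_passes` scaling above, spelled out here because the referenced go/dp-sgd-shared-weights doc is internal. Assume (my notation, not from the diff) that a layer is used k times in the forward pass, its i-th occurrence contributes a per-example gradient g_i to the shared weights, and each c_i^2 = fn(grad) upper-bounds ||g_i||^2. Then, by the triangle inequality and Cauchy-Schwarz:

\left\lVert \sum_{i=1}^{k} g_i \right\rVert^2
  \;\le\; k \sum_{i=1}^{k} \lVert g_i \rVert^2
  \;\le\; k \sum_{i=1}^{k} c_i^2
\quad\Longrightarrow\quad
\left\lVert \sum_{i=1}^{k} g_i \right\rVert
  \;\le\; \sqrt{k \sum_{i=1}^{k} c_i^2}.

This is exactly the quantity accumulated by appending `num_passes * fn(grad)` (i.e. k * c_i^2) to `sqr_norm_list` before the final `tf.sqrt(tf.reduce_sum(...))`, so the computed per-example norm is an over-estimate of the true norm, which is what the new tests assert.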
@@ -197,5 +197,108 @@ class ComputeClippedGradsAndOutputsTest(
     self.assertAlmostEqual(computed_norm, true_norm)


+class SharedLayerTest(tf.test.TestCase, parameterized.TestCase):
+
+  def _make_shared_model(self, num_inputs, input_dim):
+    base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)])
+    inputs = []
+    outputs = []
+    for _ in range(num_inputs):
+      input_tensor = tf.keras.Input(shape=[input_dim])
+      inputs.append(input_tensor)
+      output_tensor = base_model(input_tensor)
+      outputs.append(output_tensor)
+    return tf.keras.Model(inputs=inputs, outputs=tf.add_n(outputs))
+
+  def _get_computed_and_true_norms(self, model, x_batch, y_batch, is_eager):
+    model.compile(
+        loss=tf.keras.losses.MeanSquaredError(reduction='none'),
+        run_eagerly=is_eager,
+    )
+    computed_norms = clip_grads.compute_gradient_norms(
+        model, layer_registry.make_default_layer_registry(), x_batch, y_batch
+    )
+    with tf.GradientTape() as tape:
+      y_pred = model(x_batch)
+      loss_value = model.loss(y_pred, y_batch)
+    true_grads = tape.jacobian(loss_value, model.trainable_variables)
+    true_norms = tf.sqrt(
+        tf.add_n([tf.reduce_sum(tf.square(g), axis=[1, 2]) for g in true_grads])
+    )
+    return computed_norms, true_norms
+
+  @parameterized.product(
+      num_inputs=[1, 2, 10],
+      batch_size=[1, 2],
+      input_dim=[1, 3],
+      is_eager=[True, False],
+  )
+  def test_gradient_norms_on_multiple_inputs_are_upper_bounded(
+      self, num_inputs, batch_size, input_dim, is_eager
+  ):
+    model = self._make_shared_model(num_inputs, input_dim)
+    model.compile(
+        loss=tf.keras.losses.MeanSquaredError(reduction='none'),
+        run_eagerly=is_eager,
+    )
+    x_batch = [
+        float(k + 1) * tf.ones([batch_size, input_dim], dtype=tf.float64)
+        for k in range(num_inputs)
+    ]
+    y_batch = tf.reshape(
+        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
+    )
+    computed_norms, true_norms = self._get_computed_and_true_norms(
+        model, x_batch, y_batch, is_eager
+    )
+    self.assertAllLessEqual(true_norms - computed_norms, 1e-3)
+
+  @parameterized.product(
+      num_repeats=[1, 2, 10],
+      batch_size=[1, 2],
+      input_dim=[1, 3],
+      is_eager=[True, False],
+  )
+  def test_gradient_norms_on_single_repeated_input_are_upper_bounded(
+      self, num_repeats, batch_size, input_dim, is_eager
+  ):
+    base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)])
+    inputs = tf.keras.layers.Input([input_dim])
+    outputs = tf.add_n([base_model(inputs) for _ in range(num_repeats)])
+    model = tf.keras.Model(inputs=inputs, outputs=outputs)
+    x_batch = tf.ones([batch_size, input_dim], dtype=tf.float64)
+    y_batch = tf.reshape(
+        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
+    )
+    computed_norms, true_norms = self._get_computed_and_true_norms(
+        model, x_batch, y_batch, is_eager
+    )
+    self.assertAllLessEqual(true_norms - computed_norms, 1e-3)
+
+  @parameterized.product(
+      batch_size=[1, 2],
+      input_dim=[1, 3],
+      is_eager=[True, False],
+  )
+  def test_gradient_norms_on_input_slices_are_upper_bounded(
+      self, batch_size, input_dim, is_eager
+  ):
+    base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)])
+    inputs = tf.keras.layers.Input([input_dim, 2])
+    outputs = base_model(inputs[:, :, 0]) + base_model(inputs[:, :, 1])
+    model = tf.keras.Model(inputs=inputs, outputs=outputs)
+    x_batch = tf.reshape(
+        tf.range(batch_size * input_dim * 2, dtype=tf.float64),
+        [batch_size, input_dim, -1],
+    )
+    y_batch = tf.reshape(
+        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
+    )
+    computed_norms, true_norms = self._get_computed_and_true_norms(
+        model, x_batch, y_batch, is_eager
+    )
+    self.assertAllLessEqual(true_norms - computed_norms, 1e-3)
+
+
 if __name__ == '__main__':
   tf.test.main()
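One mechanism worth calling out for readers of the diff above: the new `layer_grad_vars = tape.gradient(summed_loss, layer_vars, ...)` call passes a dict that maps each layer id to a list of per-call tensors, and relies on `tf.GradientTape.gradient` returning gradients in the same nested structure. A small standalone sketch of that behavior, illustrative only and not code from this change (the key name is made up):

import tensorflow as tf

v = tf.Variable([[1.0], [2.0]])
with tf.GradientTape() as tape:
  out1 = tf.matmul(tf.ones([3, 2]), v)        # first pass through the shared weight
  out2 = tf.matmul(2.0 * tf.ones([3, 2]), v)  # second pass through the same weight
  loss = tf.reduce_sum(out1) + tf.reduce_sum(out2)

# Hypothetical grouping key, mirroring str(id(layer_instance)) in the diff.
sources = {'layer_0': [out1, out2]}
grads = tape.gradient(
    loss, sources, unconnected_gradients=tf.UnconnectedGradients.ZERO
)
# The result mirrors the input structure: one key, one gradient per call.
assert set(grads.keys()) == {'layer_0'}
assert len(grads['layer_0']) == 2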