Add better tests for clip_grads.py

PiperOrigin-RevId: 509529435
This commit is contained in:
A. Unique TensorFlower 2023-02-14 08:01:24 -08:00
parent 430f103354
commit 13534e5159

View file

@ -36,14 +36,20 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch):
sqr_norms = []
for var in input_model.trainable_variables:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, tf.rank(jacobian))
reduction_axes = tf.range(1, len(jacobian.shape))
sqr_norms.append(tf.reduce_sum(tf.square(jacobian), axis=reduction_axes))
sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
def get_computed_and_true_norms(
model_generator, layer_generator, input_dim, output_dim, is_eager, x_input
model_generator,
layer_generator,
input_dims,
output_dim,
is_eager,
x_input,
rng_seed=777,
):
"""Obtains the true and computed gradient norms for a model and batch input.
@ -58,10 +64,11 @@ def get_computed_and_true_norms(
layer_generator: A function which takes in two arguments: `idim` and `odim`.
Returns a `tf.keras.layers.Layer` that accepts input tensors of dimension
`idim` and returns output tensors of dimension `odim`.
input_dim: The input dimension of the test `tf.keras.Model` instance.
input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
output_dim: The output dimension of the test `tf.keras.Model` instance.
is_eager: A `bool` that is `True` if the model should be run eagerly.
x_input: `tf.Tensor` inputs to be tested.
rng_seed: An `int` used to initialize model weights.
Returns:
A `tuple` `(computed_norm, true_norms)`. The first element contains the
@ -70,8 +77,7 @@ def get_computed_and_true_norms(
model and layer generators. The second element contains the true clipped
gradient norms under the aforementioned setting.
"""
tf.keras.utils.set_random_seed(777)
model = model_generator(layer_generator, input_dim, output_dim)
model = model_generator(layer_generator, input_dims, output_dim)
model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
loss=tf.keras.losses.MeanSquaredError(
@ -82,9 +88,11 @@ def get_computed_and_true_norms(
y_pred = model(x_input)
y_batch = tf.ones_like(y_pred)
registry = layer_registry.make_default_layer_registry()
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
model, x_input, y_batch, layer_registry=registry
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(model, x_input, y_batch)
return (computed_norms, true_norms)
@ -140,27 +148,58 @@ def make_two_tower_model(layer_generator, input_dim, output_dim):
return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs)
def make_bow_model(layer_generator, input_dim, output_dim):
def make_bow_model(layer_generator, input_dims, output_dim):
del layer_generator
inputs = tf.keras.Input(shape=(input_dim,))
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dim argument, which is the number of ids
# in eache example.
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
feature_embs = emb_layer(inputs)
example_embs = tf.reduce_sum(feature_embs, axis=1)
reduction_axes = tf.range(1, len(feature_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
return tf.keras.Model(inputs=inputs, outputs=example_embs)
def make_dense_bow_model(layer_generator, input_dim, output_dim):
def make_dense_bow_model(layer_generator, input_dims, output_dim):
del layer_generator
inputs = tf.keras.Input(shape=(input_dim,))
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dim argument, which is the number of ids
# in eache example.
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
cardinality = 10
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
example_embs = tf.reduce_sum(feature_embs, axis=1)
reduction_axes = tf.range(1, len(feature_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
outputs = tf.keras.layers.Dense(1)(example_embs)
return tf.keras.Model(inputs=inputs, outputs=outputs)
def make_weighted_bow_model(layer_generator, input_dims, output_dim):
# NOTE: This model only accepts dense input tensors.
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dim argument, which is the number of ids
# in eache example.
cardinality = 10
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
feature_weights = tf.random.uniform(tf.shape(feature_embs))
weighted_embs = feature_embs * feature_weights
reduction_axes = tf.range(1, len(weighted_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
)
outputs = tf.keras.layers.Dense(1)(example_embs)
return tf.keras.Model(inputs=inputs, outputs=outputs)
@ -211,6 +250,7 @@ def get_embedding_model_generators():
return {
'bow1': make_bow_model,
'bow2': make_dense_bow_model,
'weighted_bow1': make_weighted_bow_model,
}
@ -267,10 +307,20 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
# supports them for embeddings.
@parameterized.product(
x_batch=[
tf.convert_to_tensor([[0, 1], [1, 0]], dtype_hint=tf.int32),
# 2D inputs.
tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
tf.ragged.constant(
[[0], [1], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
),
# 3D inputs.
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
tf.convert_to_tensor(
[[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
dtype=tf.int32,
),
],
model_name=list(get_embedding_model_generators().keys()),
@ -280,16 +330,21 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
def test_gradient_norms_on_various_models(
self, x_batch, model_name, output_dim, is_eager
):
model_generator = get_embedding_model_generators()[model_name]
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=model_generator,
layer_generator=None,
input_dim=2,
output_dim=output_dim,
is_eager=is_eager,
x_input=x_batch,
)
self.assertAllClose(computed_norms, true_norms)
valid_test_input = (
not isinstance(x_batch, tf.RaggedTensor)
and model_name == 'weighted_bow1'
) or (model_name != 'weighted_bow1')
if valid_test_input:
model_generator = get_embedding_model_generators()[model_name]
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=model_generator,
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
is_eager=is_eager,
x_input=x_batch,
)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
if __name__ == '__main__':