Add fast gradient clipping tests.
PiperOrigin-RevId: 504923799
parent a3b14ae20a
commit bc84ed7bfb
3 changed files with 313 additions and 3 deletions

@@ -1,4 +1,4 @@
-load("@rules_python//python:defs.bzl", "py_library")
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
 
 package(default_visibility = ["//visibility:public"])
 
@@ -23,3 +23,14 @@ py_library(
         ":layer_registry_factories",
     ],
 )
+
+py_test(
+    name = "clip_grads_test",
+    srcs = ["clip_grads_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [
+        ":clip_grads",
+        ":layer_registry_factories",
+    ],
+)
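
The new `py_test` target reuses the same deps as the existing `py_library`. Assuming the BUILD file lives under tensorflow_privacy/privacy/fast_gradient_clipping (consistent with the imports in the new test below), the added target can be run locally with, for example:

    bazel test //tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads_test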

@@ -0,0 +1,296 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from absl.testing import parameterized
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories


# ==============================================================================
# Helper functions.
# ==============================================================================
def compute_true_gradient_norms(input_model, x_batch, y_batch):
  """Computes the real gradient norms for an input `(model, x, y)`."""
  loss_config = input_model.loss.get_config()
  loss_config['reduction'] = tf.keras.losses.Reduction.NONE
  per_example_loss_fn = input_model.loss.from_config(loss_config)
  with tf.GradientTape(persistent=True) as tape:
    y_pred = input_model(x_batch)
    loss = per_example_loss_fn(y_batch, y_pred)
    if isinstance(loss, tf.RaggedTensor):
      loss = loss.to_tensor()
  sqr_norms = []
  for var in input_model.trainable_variables:
    jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
    reduction_axes = tf.range(1, tf.rank(jacobian))
    sqr_norms.append(tf.reduce_sum(tf.square(jacobian), axis=reduction_axes))
  sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
  return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
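
# Worked example (illustrative note, not part of the commit): for a model with
# a single scalar weight w and per-example squared-error losses
# loss_i = (w * x_i - y_i)**2, the computation above reduces to
#   norm_i = |d loss_i / d w| = |2 * x_i * (w * x_i - y_i)|,
# e.g. w = 3, x = [1, 2], y = [0, 0] gives per-example norms [6, 24]. With
# several trainable variables, the per-variable squared norms are summed
# before the final square root, which is what the tf.stack / tf.sqrt lines
# implement.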


def get_computed_and_true_norms(
    model_generator, layer_generator, input_dim, output_dim, is_eager, x_input
):
  """Obtains the true and computed gradient norms for a model and batch input.

  Helpful testing wrapper function used to avoid code duplication.

  Args:
    model_generator: A function which takes in three arguments:
      `layer_generator`, `idim`, and `odim`. Returns a `tf.keras.Model` that
      accepts input tensors of dimension `idim` and returns output tensors of
      dimension `odim`. Layers of the model are based on the `layer_generator`
      (see below for its description).
    layer_generator: A function which takes in two arguments: `idim` and
      `odim`. Returns a `tf.keras.layers.Layer` that accepts input tensors of
      dimension `idim` and returns output tensors of dimension `odim`.
    input_dim: The input dimension of the test `tf.keras.Model` instance.
    output_dim: The output dimension of the test `tf.keras.Model` instance.
    is_eager: A `bool` that is `True` if the model should be run eagerly.
    x_input: `tf.Tensor` inputs to be tested.

  Returns:
    A `tuple` `(computed_norms, true_norms)`. The first element contains the
    clipped gradient norms that are generated by
    `clip_grads.compute_gradient_norms()` under the setting specified by the
    given model and layer generators. The second element contains the true
    clipped gradient norms under the aforementioned setting.
  """
  tf.keras.utils.set_random_seed(777)
  model = model_generator(layer_generator, input_dim, output_dim)
  model.compile(
      optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
      loss=tf.keras.losses.MeanSquaredError(
          reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
      ),
      run_eagerly=is_eager,
  )
  y_pred = model(x_input)
  y_batch = tf.ones_like(y_pred)
  registry = layer_registry_factories.make_default_layer_registry()
  computed_norms = clip_grads.compute_gradient_norms(
      model, x_input, y_batch, layer_registry=registry
  )
  true_norms = compute_true_gradient_norms(model, x_input, y_batch)
  return (computed_norms, true_norms)
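
# Example usage (illustrative, not part of the commit), comparing the two norm
# computations for a small dense model on a fixed batch:
#
#   computed, true = get_computed_and_true_norms(
#       model_generator=make_two_layer_sequential_model,
#       layer_generator=lambda a, b: tf.keras.layers.Dense(b),
#       input_dim=2,
#       output_dim=1,
#       is_eager=True,
#       x_input=tf.ones((3, 2)),
#   )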


# ==============================================================================
# Model generators.
# ==============================================================================
def make_two_layer_sequential_model(layer_generator, input_dim, output_dim):
  """Creates a 2-layer sequential model."""
  model = tf.keras.Sequential()
  model.add(tf.keras.Input(shape=(input_dim,)))
  model.add(layer_generator(input_dim, output_dim))
  model.add(tf.keras.layers.Dense(1))
  return model


def make_three_layer_sequential_model(layer_generator, input_dim, output_dim):
  """Creates a 3-layer sequential model."""
  model = tf.keras.Sequential()
  model.add(tf.keras.Input(shape=(input_dim,)))
  layer1 = layer_generator(input_dim, output_dim)
  model.add(layer1)
  if isinstance(layer1, tf.keras.layers.Embedding):
    # Having multiple consecutive embedding layers does not make sense since
    # embedding layers only map integers to real-valued vectors.
    model.add(tf.keras.layers.Dense(output_dim))
  else:
    model.add(layer_generator(output_dim, output_dim))
  model.add(tf.keras.layers.Dense(1))
  return model


def make_two_layer_functional_model(layer_generator, input_dim, output_dim):
  """Creates a 2-layer 1-input functional model with a pre-output square op."""
  inputs = tf.keras.Input(shape=(input_dim,))
  layer1 = layer_generator(input_dim, output_dim)
  temp1 = layer1(inputs)
  temp2 = tf.square(temp1)
  outputs = tf.keras.layers.Dense(1)(temp2)
  return tf.keras.Model(inputs=inputs, outputs=outputs)


def make_two_tower_model(layer_generator, input_dim, output_dim):
  """Creates a 2-layer 2-input functional model."""
  inputs1 = tf.keras.Input(shape=(input_dim,))
  layer1 = layer_generator(input_dim, output_dim)
  temp1 = layer1(inputs1)
  inputs2 = tf.keras.Input(shape=(input_dim,))
  layer2 = layer_generator(input_dim, output_dim)
  temp2 = layer2(inputs2)
  temp3 = tf.add(temp1, temp2)
  outputs = tf.keras.layers.Dense(1)(temp3)
  return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs)


def make_bow_model(layer_generator, input_dim, output_dim):
  """Creates a simple embedding bag-of-words model."""
  del layer_generator
  inputs = tf.keras.Input(shape=(input_dim,))
  # For the Embedding layer, input_dim is the vocabulary size. This should
  # be distinguished from the input_dim argument, which is the number of ids
  # in each example.
  emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
  feature_embs = emb_layer(inputs)
  example_embs = tf.reduce_sum(feature_embs, axis=1)
  return tf.keras.Model(inputs=inputs, outputs=example_embs)
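
# Shape note (not part of the commit): in make_bow_model above, `inputs` has
# shape (batch, input_dim), the embedding lookup yields
# (batch, input_dim, output_dim), and the reduce_sum over axis 1 produces the
# per-example bag-of-words embedding of shape (batch, output_dim).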


def make_dense_bow_model(layer_generator, input_dim, output_dim):
  """Creates an embedding bag-of-words model with a final dense layer."""
  del layer_generator
  inputs = tf.keras.Input(shape=(input_dim,))
  # For the Embedding layer, input_dim is the vocabulary size. This should
  # be distinguished from the input_dim argument, which is the number of ids
  # in each example.
  emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
  feature_embs = emb_layer(inputs)
  example_embs = tf.reduce_sum(feature_embs, axis=1)
  outputs = tf.keras.layers.Dense(1)(example_embs)
  return tf.keras.Model(inputs=inputs, outputs=outputs)


# ==============================================================================
# Factory functions.
# ==============================================================================
def get_nd_test_tensors(n):
  """Returns a list of candidate test tensors for a given dimension n."""
  return [
      tf.zeros((n,), dtype=tf.float64),
      tf.convert_to_tensor(range(n), dtype_hint=tf.float64),
  ]


def get_nd_test_batches(n):
  """Returns a list of candidate input batches of dimension n."""
  result = []
  tensors = get_nd_test_tensors(n)
  for batch_size in range(1, len(tensors) + 1, 1):
    combinations = list(
        itertools.combinations(get_nd_test_tensors(n), batch_size)
    )
    result = result + [tf.stack(ts, axis=0) for ts in combinations]
  return result
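
# For reference (not part of the commit): with n = 2, get_nd_test_tensors(2)
# returns [0., 0.] and [0., 1.] (float64), so the batches produced above are
# the three stacks [[0, 0]], [[0, 1]], and [[0, 0], [0, 1]], i.e. every
# combination of the candidate tensors at each batch size.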


def get_dense_layer_generators():

  def sigmoid_dense_layer(b):
    return tf.keras.layers.Dense(b, activation='sigmoid')

  return {
      'pure_dense': lambda a, b: tf.keras.layers.Dense(b),
      'sigmoid_dense': lambda a, b: sigmoid_dense_layer(b),
  }


def get_dense_model_generators():
  return {
      'seq1': make_two_layer_sequential_model,
      'seq2': make_three_layer_sequential_model,
      'func1': make_two_layer_functional_model,
      'tower1': make_two_tower_model,
  }


def get_embedding_model_generators():
  return {
      'bow1': make_bow_model,
      'bow2': make_dense_bow_model,
  }


# ==============================================================================
# Main tests.
# ==============================================================================
class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.product(
      input_dim=[1, 2], clip_value=[1e-6, 0.5, 1.0, 2.0, 10.0, 1e6]
  )
  def test_clip_weights(self, input_dim, clip_value):
    tol = 1e-6
    for t in get_nd_test_tensors(input_dim):
      self.assertIsNone(clip_grads.compute_clip_weights(None, t))
      weights = clip_grads.compute_clip_weights(clip_value, t)
      self.assertAllLessEqual(t * weights, clip_value + tol)
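    # Note (not part of the commit): the assertion above holds if the clip
    # weight is the usual per-example rescaling factor
    # min(1, clip_value / ||t||); vectors with norm above clip_value are
    # scaled down to norm clip_value (within tol) and smaller ones are left
    # unchanged.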


class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.product(
      model_name=list(get_dense_model_generators().keys()),
      layer_name=list(get_dense_layer_generators().keys()),
      input_dim=[1, 2],
      output_dim=[1, 2],
      is_eager=[True, False],
  )
  def test_gradient_norms_on_various_models(
      self, model_name, layer_name, input_dim, output_dim, is_eager
  ):
    model_generator = get_dense_model_generators()[model_name]
    layer_generator = get_dense_layer_generators()[layer_name]
    x_batches = get_nd_test_batches(input_dim)
    for x_batch in x_batches:
      if model_name == 'tower1':
        x_input = [x_batch, x_batch]
      else:
        x_input = x_batch
      (computed_norms, true_norms) = get_computed_and_true_norms(
          model_generator,
          layer_generator,
          input_dim,
          output_dim,
          is_eager,
          x_input,
      )
      self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)


class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):

  # TODO(wkong): Test sparse input tensors when the GitHub CI environment
  # supports them for embeddings.
  @parameterized.product(
      x_batch=[
          tf.convert_to_tensor([[0, 1], [1, 0]], dtype_hint=tf.int32),
          tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
          tf.ragged.constant(
              [[0], [1], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
          ),
      ],
      model_name=list(get_embedding_model_generators().keys()),
      output_dim=[1, 2],
      is_eager=[True, False],
  )
  def test_gradient_norms_on_various_models(
      self, x_batch, model_name, output_dim, is_eager
  ):
    model_generator = get_embedding_model_generators()[model_name]
    (computed_norms, true_norms) = get_computed_and_true_norms(
        model_generator=model_generator,
        layer_generator=None,
        input_dim=2,
        output_dim=output_dim,
        is_eager=is_eager,
        x_input=x_batch,
    )
    self.assertAllClose(computed_norms, true_norms)


if __name__ == '__main__':
  tf.test.main()

@@ -100,8 +100,11 @@ def embedding_layer_computation(layer_instance, inputs):
     and `base_vars` is the intermediate Tensor used in the chain-rule / "fast"
     clipping trick.
   """
-  if layer_instance.sparse:
-    raise NotImplementedError("Sparse output vectors are not supported.")
+  if hasattr(layer_instance, "sparse"):  # for backwards compatibility
+    if layer_instance.sparse:
+      raise NotImplementedError("Sparse output vectors are not supported.")
+  if tf.rank(*inputs) != 2:
+    raise NotImplementedError("Only 2D embedding inputs are supported.")
   # The logic below is applied to properly handle repeated embedding indices.
   # Specifically, sqr_grad_norms will contain the total counts of each embedding
   # index (see how it is processed in the combine_pre_and_post_sqr_norms()
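
To illustrate the point about repeated embedding indices (a sketch, not part of the commit; it only demonstrates the gradient-accumulation behavior that the truncated comment above refers to):

    import tensorflow as tf

    emb = tf.Variable(tf.ones((3, 2)))  # vocabulary size 3, output_dim 2
    ids = tf.constant([[0, 0]])         # index 0 appears twice in one example
    with tf.GradientTape() as tape:
      out = tf.reduce_sum(tf.gather(emb, ids))
    grad = tape.gradient(out, emb)      # tf.IndexedSlices with indices [0, 0]
    dense = tf.math.unsorted_segment_sum(
        grad.values, grad.indices, num_segments=3
    )
    # dense[0] == [2., 2.]: the repeated index contributes once per occurrence,
    # which is why per-example squared norms must track these counts.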