Add a tf.GradientTape argument to the layer registry functions

PiperOrigin-RevId: 512160655
Author: A. Unique TensorFlower
Date: 2023-02-24 14:14:01 -08:00
parent 4dd8d0ffde
commit dda7fa8b39
3 changed files with 95 additions and 31 deletions


@@ -42,12 +42,9 @@ def get_registry_generator_fn(tape, layer_registry):
% layer_instance.__class__.__name__
)
registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, transform, layer_sqr_norm_fn) = registry_fn(
layer_instance, args
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, tape
)
if tape is not None:
tape.watch(layer_vars)
layer_outputs = transform(layer_vars) if transform else layer_vars
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
else:
# Non-trainable layer.
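
With this change, a registry function receives the `tf.GradientTape` as a third argument, watches its own intermediate tensor, and returns the layer outputs directly instead of a `transform` callable. Below is a minimal sketch of a registry function for a hypothetical custom layer under the new `(layer_instance, args, tape)` contract; the layer, the name `my_layer_computation`, and the norm computation are illustrative only, not part of this commit.

import tensorflow as tf

def my_layer_computation(layer_instance, inputs, tape):
  # Intermediate tensor used by the chain-rule / "fast" clipping trick.
  base_vars = layer_instance(*inputs)
  # The registry function, not the caller, now watches `base_vars`.
  tape.watch(base_vars)

  def sqr_norm_fn(base_vars_grads):
    # Placeholder: a real registry function converts these gradients into
    # per-example squared L2 norms of the layer's trainable variables.
    reduction_axes = tf.range(1, tf.rank(base_vars_grads))
    return tf.reduce_sum(tf.square(base_vars_grads), axis=reduction_axes)

  # For a layer with no activation, the outputs are `base_vars` itself;
  # they are returned directly, with no separate `transform` step.
  return base_vars, base_vars, sqr_norm_fn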


@@ -21,8 +21,38 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
# ==============================================================================
# Helper functions.
# Helper functions and classes.
# ==============================================================================
class DoubleDense(tf.keras.layers.Layer):
"""Generates two dense layers nested together."""
def __init__(self, units):
super().__init__()
self.dense1 = tf.keras.layers.Dense(units)
self.dense2 = tf.keras.layers.Dense(1)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
def double_dense_layer_computation(layer_instance, inputs, tape):
"""Layer registry function for the custom `DoubleDense` layer class."""
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
layer_instance.dense1, inputs, tape
)
vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation(
layer_instance.dense2, (outputs,), tape
)
def sqr_norm_fn(base_vars):
norms1 = sqr_norm_fn1(base_vars[0])
norms2 = sqr_norm_fn2(base_vars[1])
return norms1 + norms2
return [vars1, vars2], outputs, sqr_norm_fn
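
A composed registry function like `double_dense_layer_computation` is attached to a registry the same way as the built-in entries; the custom-layer test further below does exactly this. As a short usage sketch:

# Register the custom layer so that fast gradient clipping knows how to
# handle models that contain it.
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)
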
def compute_true_gradient_norms(input_model, x_batch, y_batch):
"""Computes the real gradient norms for an input `(model, x, y)`."""
loss_config = input_model.loss.get_config()
@@ -50,6 +80,7 @@ def get_computed_and_true_norms(
is_eager,
x_input,
rng_seed=777,
registry=None,
):
"""Obtains the true and computed gradient norms for a model and batch input.
@@ -69,6 +100,7 @@ def get_computed_and_true_norms(
is_eager: A `bool` that is `True` if the model should be run eagerly.
x_input: `tf.Tensor` inputs to be tested.
rng_seed: An `int` used to initialize model weights.
registry: A `layer_registry.LayerRegistry` instance.
Returns:
A `tuple` `(computed_norm, true_norms)`. The first element contains the
@@ -87,7 +119,6 @@ def get_computed_and_true_norms(
)
y_pred = model(x_input)
y_batch = tf.ones_like(y_pred)
registry = layer_registry.make_default_layer_registry()
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
model, x_input, y_batch, layer_registry=registry
@@ -285,6 +316,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name]
x_batches = get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry()
for x_batch in x_batches:
if model_name == 'tower1':
x_input = [x_batch, x_batch]
@@ -297,6 +329,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
output_dim,
is_eager,
x_input,
registry=default_registry,
)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@@ -335,6 +368,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
and model_name == 'weighted_bow1'
) or (model_name != 'weighted_bow1')
if valid_test_input:
default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name]
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=model_generator,
@@ -343,6 +377,33 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
output_dim=output_dim,
is_eager=is_eager,
x_input=x_batch,
registry=default_registry,
)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
input_dim=[1, 2],
output_dim=[1, 2],
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self, input_dim, output_dim, is_eager
):
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)
x_batches = get_nd_test_batches(input_dim)
for x_batch in x_batches:
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=make_two_layer_sequential_model,
layer_generator=lambda a, b: DoubleDense(b),
input_dims=input_dim,
output_dim=output_dim,
is_eager=is_eager,
x_input=x_batch,
registry=registry,
)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
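
Outside the test harness, the same registry would be handed straight to the clipping entry point. A sketch, assuming a compiled Keras `model` containing a `DoubleDense` layer and a batch `(x_batch, y_batch)`:

# Per-example gradient norms for a model with a registered custom layer.
per_example_norms = clip_grads.compute_gradient_norms(
    model, x_batch, y_batch, layer_registry=registry
)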


@@ -72,7 +72,7 @@ class LayerRegistry:
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def dense_layer_computation(layer_instance, inputs):
def dense_layer_computation(layer_instance, inputs, tape):
"""Registry function for `tf.keras.layers.Dense`.
The logic for this computation is based on the following paper:
@@ -85,23 +85,27 @@ def dense_layer_computation(layer_instance, inputs):
layer_instance: A `tf.keras.layers.Dense` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
Returns:
A `tuple` `(base_vars, transform, sqr_norm_fn)`. `base_vars` is the
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick,
`transform` is a function that maps `base_vars` to the layer outputs, and
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
represents the output of the call `tape.gradient(summed_loss, base_vars)`
where `tape` is a `tf.GradientTape` instance that records the dense
layer computation and `summed_loss` is the sum of the per-example losses
of the underlying model. This function then returns the per-example squared
L2 gradient norms of the trainable variables in `layer_instance`. These
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
`outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
a function that takes one input, a `tf.Tensor` that represents the output
of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
`tf.GradientTape` instance that records the dense layer computation and
`summed_loss` is the sum of the per-example losses of the underlying model.
This function then returns the per-example squared L2 gradient norms of the
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
"""
orig_activation = layer_instance.activation
layer_instance.activation = None
base_vars = layer_instance(*inputs)
tape.watch(base_vars)
layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars
def sqr_norm_fn(base_vars_grads):
sqr_inputs = tf.square(*inputs)
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
@@ -118,10 +122,10 @@ def dense_layer_computation(layer_instance, inputs):
)
return input_sqr_norms * base_vars_sqr_norms
return base_vars, layer_instance.activation, sqr_norm_fn
return base_vars, outputs, sqr_norm_fn
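
Because `tape.watch(base_vars)` now happens inside the registry function, a caller only needs to invoke it under an open tape and feed the resulting gradient into `sqr_norm_fn`. A hedged end-to-end sketch of how the returned triple is consumed (the loss below is illustrative, not the library's internal plumbing):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

dense = tf.keras.layers.Dense(4)
x = tf.random.normal([8, 3])  # batch of 8 examples
with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = layer_registry.dense_layer_computation(
      dense, (x,), tape
  )
  # Illustrative per-example losses computed from the layer outputs.
  per_example_losses = tf.reduce_sum(tf.square(outputs), axis=1)
  summed_loss = tf.reduce_sum(per_example_losses)
grads = tape.gradient(summed_loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(grads)  # 1D tensor of length 8
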
def embedding_layer_computation(layer_instance, inputs):
def embedding_layer_computation(layer_instance, inputs, tape):
"""Registry function for `tf.keras.layers.Embedding`.
The logic of this computation is based on the `tf.keras.layers.Dense`
@@ -134,17 +138,20 @@ def embedding_layer_computation(layer_instance, inputs):
layer_instance: A `tf.keras.layers.Embedding` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
Returns:
A `tuple` `(base_vars, transform, sqr_norm_fn)`, `base_vars` is the
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick,
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
represents the output of the call `tape.gradient(summed_loss, base_vars)`
where `tape` is a `tf.GradientTape` instance that records the dense
layer computation and `summed_loss` is the sum of the per-example losses
of the underlying model. This function then returns the per-example squared
L2 gradient norms of the trainable variables in `layer_instance`. These
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
`outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
a function that takes one input, a `tf.Tensor` that represents the output
of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
`tf.GradientTape` instance that records the dense layer computation and
`summed_loss` is the sum of the per-example losses of the underlying model.
This function then returns the per-example squared L2 gradient norms of the
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
"""
if hasattr(layer_instance, "sparse"): # for backwards compatibility
if layer_instance.sparse:
@@ -161,9 +168,8 @@ def embedding_layer_computation(layer_instance, inputs):
)
input_ids = tf.cast(*inputs, tf.int32)
base_vars = layer_instance.trainable_variables[0]
def lookup_inputs(embeddings):
return tf.nn.embedding_lookup(embeddings, input_ids)
tape.watch(base_vars)
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
def sqr_norm_fn(base_vars_grads):
# Get a 1D tensor of the row indices.
@@ -213,7 +219,7 @@ def embedding_layer_computation(layer_instance, inputs):
num_segments=nrows,
) # fill in empty inputs
return base_vars, lookup_inputs, sqr_norm_fn
return base_vars, outputs, sqr_norm_fn
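
The embedding registry function follows the same calling convention. A sketch mirroring the dense example above, again with an illustrative loss; in eager execution the gradient with respect to the embedding table typically arrives as `tf.IndexedSlices`:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

emb = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
ids = tf.constant([[1, 2], [3, 3]])  # batch of 2 examples, 2 ids each
emb(ids)  # build the layer so its embedding table exists
with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = layer_registry.embedding_layer_computation(
      emb, (ids,), tape
  )
  summed_loss = tf.reduce_sum(tf.square(outputs))  # illustrative loss
grads = tape.gradient(summed_loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(grads)  # 1D tensor of length 2
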
# ==============================================================================