forked from 626_privacy/tensorflow_privacy
Add a `tf.GradientTape` argument to the layer registry functions
PiperOrigin-RevId: 512160655
This commit is contained in:
parent
4dd8d0ffde
commit
dda7fa8b39
3 changed files with 95 additions and 31 deletions
|
@ -42,12 +42,9 @@ def get_registry_generator_fn(tape, layer_registry):
|
||||||
% layer_instance.__class__.__name__
|
% layer_instance.__class__.__name__
|
||||||
)
|
)
|
||||||
registry_fn = layer_registry.lookup(layer_instance)
|
registry_fn = layer_registry.lookup(layer_instance)
|
||||||
(layer_vars, transform, layer_sqr_norm_fn) = registry_fn(
|
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
||||||
layer_instance, args
|
layer_instance, args, tape
|
||||||
)
|
)
|
||||||
if tape is not None:
|
|
||||||
tape.watch(layer_vars)
|
|
||||||
layer_outputs = transform(layer_vars) if transform else layer_vars
|
|
||||||
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
|
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
|
||||||
else:
|
else:
|
||||||
# Non-trainable layer.
|
# Non-trainable layer.
|
||||||
|
|
|
@ -21,8 +21,38 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Helper functions.
|
# Helper functions and classes.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
class DoubleDense(tf.keras.layers.Layer):
  """Custom Keras layer that chains two `Dense` sub-layers.

  The first sub-layer is built with `units` outputs. The second sub-layer is
  always built with a single output and is applied to the result of the first.
  """

  def __init__(self, units):
    """Creates the two sub-layers; `units` sizes the first one."""
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(units)
    self.dense2 = tf.keras.layers.Dense(1)

  def call(self, inputs):
    """Applies `dense1` to `inputs` and then `dense2` to that result."""
    return self.dense2(self.dense1(inputs))
|
||||||
|
|
||||||
|
|
||||||
|
def double_dense_layer_computation(layer_instance, inputs, tape):
  """Registry function for `DoubleDense`, built from two dense computations.

  Args:
    layer_instance: A `DoubleDense` instance; its `dense1` and `dense2`
      attributes are each passed to `layer_registry.dense_layer_computation`.
    inputs: The inputs forwarded to the computation for `dense1`.
    tape: Forwarded unchanged to both `dense_layer_computation` calls.

  Returns:
    A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is a two-element
    list holding the first value returned for `dense1` and for `dense2`,
    `outputs` is the second value returned for `dense2`, and `sqr_norm_fn` is a
    function that takes a two-element sequence, applies each sub-layer's
    squared norm function to the matching element, and returns their sum.
  """
  dense_computation = layer_registry.dense_layer_computation
  first_vars, hidden, first_sqr_norm_fn = dense_computation(
      layer_instance.dense1, inputs, tape
  )
  # The second computation receives the first one's outputs as a 1-tuple.
  second_vars, outputs, second_sqr_norm_fn = dense_computation(
      layer_instance.dense2, (hidden,), tape
  )

  def sqr_norm_fn(base_vars):
    # One entry per sub-layer, in the same order as the returned list.
    return first_sqr_norm_fn(base_vars[0]) + second_sqr_norm_fn(base_vars[1])

  return [first_vars, second_vars], outputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
def compute_true_gradient_norms(input_model, x_batch, y_batch):
|
def compute_true_gradient_norms(input_model, x_batch, y_batch):
|
||||||
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
||||||
loss_config = input_model.loss.get_config()
|
loss_config = input_model.loss.get_config()
|
||||||
|
@ -50,6 +80,7 @@ def get_computed_and_true_norms(
|
||||||
is_eager,
|
is_eager,
|
||||||
x_input,
|
x_input,
|
||||||
rng_seed=777,
|
rng_seed=777,
|
||||||
|
registry=None,
|
||||||
):
|
):
|
||||||
"""Obtains the true and computed gradient norms for a model and batch input.
|
"""Obtains the true and computed gradient norms for a model and batch input.
|
||||||
|
|
||||||
|
@ -69,6 +100,7 @@ def get_computed_and_true_norms(
|
||||||
is_eager: A `bool` that is `True` if the model should be run eagerly.
|
is_eager: A `bool` that is `True` if the model should be run eagerly.
|
||||||
x_input: `tf.Tensor` inputs to be tested.
|
x_input: `tf.Tensor` inputs to be tested.
|
||||||
rng_seed: An `int` used to initialize model weights.
|
rng_seed: An `int` used to initialize model weights.
|
||||||
|
registry: A `layer_registry.LayerRegistry` instance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tuple` `(computed_norm, true_norms)`. The first element contains the
|
A `tuple` `(computed_norm, true_norms)`. The first element contains the
|
||||||
|
@ -87,7 +119,6 @@ def get_computed_and_true_norms(
|
||||||
)
|
)
|
||||||
y_pred = model(x_input)
|
y_pred = model(x_input)
|
||||||
y_batch = tf.ones_like(y_pred)
|
y_batch = tf.ones_like(y_pred)
|
||||||
registry = layer_registry.make_default_layer_registry()
|
|
||||||
tf.keras.utils.set_random_seed(rng_seed)
|
tf.keras.utils.set_random_seed(rng_seed)
|
||||||
computed_norms = clip_grads.compute_gradient_norms(
|
computed_norms = clip_grads.compute_gradient_norms(
|
||||||
model, x_input, y_batch, layer_registry=registry
|
model, x_input, y_batch, layer_registry=registry
|
||||||
|
@ -285,6 +316,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
model_generator = get_dense_model_generators()[model_name]
|
model_generator = get_dense_model_generators()[model_name]
|
||||||
layer_generator = get_dense_layer_generators()[layer_name]
|
layer_generator = get_dense_layer_generators()[layer_name]
|
||||||
x_batches = get_nd_test_batches(input_dim)
|
x_batches = get_nd_test_batches(input_dim)
|
||||||
|
default_registry = layer_registry.make_default_layer_registry()
|
||||||
for x_batch in x_batches:
|
for x_batch in x_batches:
|
||||||
if model_name == 'tower1':
|
if model_name == 'tower1':
|
||||||
x_input = [x_batch, x_batch]
|
x_input = [x_batch, x_batch]
|
||||||
|
@ -297,6 +329,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
output_dim,
|
output_dim,
|
||||||
is_eager,
|
is_eager,
|
||||||
x_input,
|
x_input,
|
||||||
|
registry=default_registry,
|
||||||
)
|
)
|
||||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||||
|
|
||||||
|
@ -335,6 +368,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
and model_name == 'weighted_bow1'
|
and model_name == 'weighted_bow1'
|
||||||
) or (model_name != 'weighted_bow1')
|
) or (model_name != 'weighted_bow1')
|
||||||
if valid_test_input:
|
if valid_test_input:
|
||||||
|
default_registry = layer_registry.make_default_layer_registry()
|
||||||
model_generator = get_embedding_model_generators()[model_name]
|
model_generator = get_embedding_model_generators()[model_name]
|
||||||
(computed_norms, true_norms) = get_computed_and_true_norms(
|
(computed_norms, true_norms) = get_computed_and_true_norms(
|
||||||
model_generator=model_generator,
|
model_generator=model_generator,
|
||||||
|
@ -343,6 +377,33 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
output_dim=output_dim,
|
output_dim=output_dim,
|
||||||
is_eager=is_eager,
|
is_eager=is_eager,
|
||||||
x_input=x_batch,
|
x_input=x_batch,
|
||||||
|
registry=default_registry,
|
||||||
|
)
|
||||||
|
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.product(
|
||||||
|
input_dim=[1, 2],
|
||||||
|
output_dim=[1, 2],
|
||||||
|
is_eager=[True, False],
|
||||||
|
)
|
||||||
|
def test_gradient_norms_on_various_models(
|
||||||
|
self, input_dim, output_dim, is_eager
|
||||||
|
):
|
||||||
|
registry = layer_registry.make_default_layer_registry()
|
||||||
|
registry.insert(DoubleDense, double_dense_layer_computation)
|
||||||
|
x_batches = get_nd_test_batches(input_dim)
|
||||||
|
for x_batch in x_batches:
|
||||||
|
(computed_norms, true_norms) = get_computed_and_true_norms(
|
||||||
|
model_generator=make_two_layer_sequential_model,
|
||||||
|
layer_generator=lambda a, b: DoubleDense(b),
|
||||||
|
input_dims=input_dim,
|
||||||
|
output_dim=output_dim,
|
||||||
|
is_eager=is_eager,
|
||||||
|
x_input=x_batch,
|
||||||
|
registry=registry,
|
||||||
)
|
)
|
||||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||||
|
|
||||||
|
|
|
@ -72,7 +72,7 @@ class LayerRegistry:
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Supported Keras layers
|
# Supported Keras layers
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
def dense_layer_computation(layer_instance, inputs):
|
def dense_layer_computation(layer_instance, inputs, tape):
|
||||||
"""Registry function for `tf.keras.layers.Dense`.
|
"""Registry function for `tf.keras.layers.Dense`.
|
||||||
|
|
||||||
The logic for this computation is based on the following paper:
|
The logic for this computation is based on the following paper:
|
||||||
|
@ -85,23 +85,27 @@ def dense_layer_computation(layer_instance, inputs):
|
||||||
layer_instance: A `tf.keras.layers.Dense` instance.
|
layer_instance: A `tf.keras.layers.Dense` instance.
|
||||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||||
`layer_instance(inputs)` returns a valid output.
|
`layer_instance(inputs)` returns a valid output.
|
||||||
|
tape: A `tf.GradientTape` instance that will be used to watch the output
|
||||||
|
`base_vars`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tuple` `(base_vars, transform, sqr_norm_fn)`. `base_vars` is the
|
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
|
||||||
intermediate Tensor used in the chain-rule / "fast" clipping trick,
|
intermediate Tensor used in the chain-rule / "fast" clipping trick,
|
||||||
`transform` is a function that maps `base_vars` to the layer outputs, and
|
`outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
|
||||||
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
a function that takes one input, a `tf.Tensor` that represents the output
|
||||||
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
|
||||||
where `tape` is a `tf.GradientTape` instance that records the dense
|
`tf.GradientTape` instance that records the dense layer computation and
|
||||||
layer computation and `summed_loss` is the sum of the per-example losses
|
`summed_loss` is the sum of the per-example losses of the underlying model.
|
||||||
of the underlying model. This function then returns the per-example squared
|
This function then returns the per-example squared L2 gradient norms of the
|
||||||
L2 gradient norms of the trainable variables in `layer_instance`. These
|
trainable variables in `layer_instance`. These squared norms should be a 1D
|
||||||
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
|
`tf.Tensor` of length `batch_size`.
|
||||||
"""
|
"""
|
||||||
orig_activation = layer_instance.activation
|
orig_activation = layer_instance.activation
|
||||||
layer_instance.activation = None
|
layer_instance.activation = None
|
||||||
base_vars = layer_instance(*inputs)
|
base_vars = layer_instance(*inputs)
|
||||||
|
tape.watch(base_vars)
|
||||||
layer_instance.activation = orig_activation
|
layer_instance.activation = orig_activation
|
||||||
|
outputs = orig_activation(base_vars) if orig_activation else base_vars
|
||||||
def sqr_norm_fn(base_vars_grads):
|
def sqr_norm_fn(base_vars_grads):
|
||||||
sqr_inputs = tf.square(*inputs)
|
sqr_inputs = tf.square(*inputs)
|
||||||
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
|
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
|
||||||
|
@ -118,10 +122,10 @@ def dense_layer_computation(layer_instance, inputs):
|
||||||
)
|
)
|
||||||
return input_sqr_norms * base_vars_sqr_norms
|
return input_sqr_norms * base_vars_sqr_norms
|
||||||
|
|
||||||
return base_vars, layer_instance.activation, sqr_norm_fn
|
return base_vars, outputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
def embedding_layer_computation(layer_instance, inputs):
|
def embedding_layer_computation(layer_instance, inputs, tape):
|
||||||
"""Registry function for `tf.keras.layers.Embedding`.
|
"""Registry function for `tf.keras.layers.Embedding`.
|
||||||
|
|
||||||
The logic of this computation is based on the `tf.keras.layers.Dense`
|
The logic of this computation is based on the `tf.keras.layers.Dense`
|
||||||
|
@ -134,17 +138,20 @@ def embedding_layer_computation(layer_instance, inputs):
|
||||||
layer_instance: A `tf.keras.layers.Embedding` instance.
|
layer_instance: A `tf.keras.layers.Embedding` instance.
|
||||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||||
`layer_instance(inputs)` returns a valid output.
|
`layer_instance(inputs)` returns a valid output.
|
||||||
|
tape: A `tf.GradientTape` instance that will be used to watch the output
|
||||||
|
`base_vars`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tuple` `(base_vars, transform, sqr_norm_fn)`, `base_vars` is the
|
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
|
||||||
intermediate Tensor used in the chain-rule / "fast" clipping trick,
|
intermediate Tensor used in the chain-rule / "fast" clipping trick,
|
||||||
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
`outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
|
||||||
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
a function that takes one input, a `tf.Tensor` that represents the output
|
||||||
where `tape` is a `tf.GradientTape` instance that records the dense
|
of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
|
||||||
layer computation and `summed_loss` is the sum of the per-example losses
|
`tf.GradientTape` instance that records the dense layer computation and
|
||||||
of the underlying model. This function then returns the per-example squared
|
`summed_loss` is the sum of the per-example losses of the underlying model.
|
||||||
L2 gradient norms of the trainable variables in `layer_instance`. These
|
This function then returns the per-example squared L2 gradient norms of the
|
||||||
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
|
trainable variables in `layer_instance`. These squared norms should be a 1D
|
||||||
|
`tf.Tensor` of length `batch_size`.
|
||||||
"""
|
"""
|
||||||
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
||||||
if layer_instance.sparse:
|
if layer_instance.sparse:
|
||||||
|
@ -161,9 +168,8 @@ def embedding_layer_computation(layer_instance, inputs):
|
||||||
)
|
)
|
||||||
input_ids = tf.cast(*inputs, tf.int32)
|
input_ids = tf.cast(*inputs, tf.int32)
|
||||||
base_vars = layer_instance.trainable_variables[0]
|
base_vars = layer_instance.trainable_variables[0]
|
||||||
|
tape.watch(base_vars)
|
||||||
def lookup_inputs(embeddings):
|
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
|
||||||
return tf.nn.embedding_lookup(embeddings, input_ids)
|
|
||||||
|
|
||||||
def sqr_norm_fn(base_vars_grads):
|
def sqr_norm_fn(base_vars_grads):
|
||||||
# Get a 1D tensor of the row indices.
|
# Get a 1D tensor of the row indices.
|
||||||
|
@ -213,7 +219,7 @@ def embedding_layer_computation(layer_instance, inputs):
|
||||||
num_segments=nrows,
|
num_segments=nrows,
|
||||||
) # fill in empty inputs
|
) # fill in empty inputs
|
||||||
|
|
||||||
return base_vars, lookup_inputs, sqr_norm_fn
|
return base_vars, outputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
Loading…
Reference in a new issue