Allow custom per-example loss functions for computing per-microbatch gradient norms.

PiperOrigin-RevId: 516897864
A. Unique TensorFlower 2023-03-15 12:27:59 -07:00
parent d7d497bb69
commit 8f4ab1a8bb
2 changed files with 51 additions and 11 deletions

View file

@@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
from typing import Dict, Iterable, Optional, Text, Union
from typing import Any, Callable, Dict, Iterable, Optional, Text, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@@ -67,6 +67,7 @@ def compute_gradient_norms(
x_batch: InputTensor,
y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry,
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
num_microbatches: Optional[lr.BatchSize] = None,
):
"""Computes the per-example loss gradient norms for given data.
@@ -86,6 +87,9 @@ def compute_gradient_norms(
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
per_example_loss_fn: If not None, used as the function to compute the
vectorized per-example loss. Otherwise, it is derived from `input_model`'s
loss function by setting its reduction to `tf.keras.losses.Reduction.NONE`.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -109,9 +113,10 @@
)
)
# Ignore the original loss function's reduction to get per-example loss.
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
if per_example_loss_fn is None:
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs)
if num_microbatches is not None:
losses = tf.reduce_mean(
@@ -204,7 +209,11 @@ def compute_pred_and_clipped_gradients(
gradient of the loss function.
"""
gradient_norms = compute_gradient_norms(
input_model, x_batch, y_batch, layer_registry, num_microbatches
input_model,
x_batch,
y_batch,
layer_registry,
num_microbatches=num_microbatches,
)
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
with tf.GradientTape() as tape:
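
For context, here is a minimal sketch of how a caller can supply the new argument. The toy model, the data, the `squared_error_per_example` helper, and the `clip_grads` module path are illustrative assumptions for this sketch, not part of the commit:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads  # assumed module path
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Hypothetical model and data, used only to exercise the new keyword argument.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())

def squared_error_per_example(y_true, y_pred):
  # A custom loss must return one (unreduced) loss value per example.
  return tf.reduce_sum(tf.square(y_true - y_pred), axis=1)

x_batch = tf.random.normal([8, 4])
y_batch = tf.random.normal([8, 2])
norms = clip_grads.compute_gradient_norms(
    model,
    x_batch,
    y_batch,
    layer_registry.make_default_layer_registry(),
    per_example_loss_fn=squared_error_per_example,  # None falls back to model.loss
)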

View file

@@ -71,16 +71,25 @@ def double_dense_layer_computation(
return [vars1, vars2], outputs, sqr_norm_fn
def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
x = tf.reshape(x, (tf.shape(x)[0], -1))
y = tf.reshape(y, (tf.shape(y)[0], -1))
# Scale by an arbitrary constant so this loss is unlikely to coincide with a built-in Keras loss.
return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1)
def compute_true_gradient_norms(
input_model: tf.keras.Model,
x_batch: tf.Tensor,
y_batch: tf.Tensor,
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
) -> layer_registry.OutputTensor:
"""Computes the real gradient norms for an input `(model, x, y)`."""
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
if per_example_loss_fn is None:
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
with tf.GradientTape(persistent=True) as tape:
y_pred = input_model(x_batch)
loss = per_example_loss_fn(y_batch, y_pred)
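
When `per_example_loss_fn` is None, both the library function and this test helper rebuild the model's compiled loss with reduction disabled, as in the fallback branch above. A standalone sketch of that fallback (the model and data here are illustrative only):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())

# Rebuild the compiled loss without reduction to get per-example losses.
loss_config = model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = model.loss.from_config(loss_config)

losses = per_example_loss_fn(tf.ones([4, 1]), tf.zeros([4, 1]))
print(losses.shape)  # (4,): one unreduced loss per example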
@@ -109,6 +118,7 @@ def get_computed_and_true_norms(
layer_generator: LayerGenerator,
input_dims: Union[int, List[int]],
output_dim: int,
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
is_eager: bool,
x_input: tf.Tensor,
@@ -130,6 +140,8 @@ def get_computed_and_true_norms(
`idim` and returns output tensors of dimension `odim`.
input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
output_dim: The output dimension of the test `tf.keras.Model` instance.
per_example_loss_fn: If not None, used as the vectorized per-example loss
function; otherwise one is derived from the model's compiled loss.
num_microbatches: The number of microbatches. None or an integer.
is_eager: A `bool` that is `True` if the model should be run eagerly.
x_input: `tf.Tensor` inputs to be tested.
@@ -159,11 +171,12 @@ def get_computed_and_true_norms(
x_input,
y_batch,
layer_registry=registry,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
model, x_input, y_batch, num_microbatches
model, x_input, y_batch, per_example_loss_fn, num_microbatches
)
return (computed_norms, true_norms)
@@ -348,6 +361,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_name=list(get_dense_layer_generators().keys()),
input_dim=[4],
output_dim=[1, 2],
per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 1, 2],
is_eager=[True, False],
)
@@ -357,6 +371,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_name,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
):
@@ -379,6 +394,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_generator,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
x_input,
@@ -419,11 +435,18 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
],
model_name=list(get_embedding_model_generators().keys()),
output_dim=[2],
per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 1, 2],
is_eager=[True],
)
def test_gradient_norms_on_various_models(
self, x_batch, model_name, output_dim, num_microbatches, is_eager
self,
x_batch,
model_name,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
):
if (
num_microbatches is not None
@@ -442,6 +465,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_input=x_batch,
@@ -455,11 +479,17 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
input_dim=[1, 2],
output_dim=[1, 2],
per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 1, 2],
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self, input_dim, output_dim, num_microbatches, is_eager
self,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
):
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)
@@ -475,6 +505,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_generator=lambda a, b: DoubleDense(b),
input_dims=input_dim,
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_input=x_batch,
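
Finally, a self-contained sanity check of the shape contract exercised by `test_loss_fn` above: a custom per-example loss must map a batch to a vector with one loss value per example (the input values below are illustrative only):

import tensorflow as tf

def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
  x = tf.reshape(x, (tf.shape(x)[0], -1))
  y = tf.reshape(y, (tf.shape(y)[0], -1))
  return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1)

y_true = tf.ones([4, 3])
y_pred = tf.zeros([4, 3])
print(test_loss_fn(y_true, y_pred))  # [9.42 9.42 9.42 9.42], shape (4,): no reduction applied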