forked from 626_privacy/tensorflow_privacy
Allow custom per example loss functions for computing per microbatch gradient norm.
PiperOrigin-RevId: 516897864
This commit is contained in:
parent
d7d497bb69
commit
8f4ab1a8bb
2 changed files with 51 additions and 11 deletions
|
@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
|
||||||
`compute_gradient_norms()` function).
|
`compute_gradient_norms()` function).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Iterable, Optional, Text, Union
|
from typing import Any, Callable, Dict, Iterable, Optional, Text, Union
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
|
@ -67,6 +67,7 @@ def compute_gradient_norms(
|
||||||
x_batch: InputTensor,
|
x_batch: InputTensor,
|
||||||
y_batch: tf.Tensor,
|
y_batch: tf.Tensor,
|
||||||
layer_registry: lr.LayerRegistry,
|
layer_registry: lr.LayerRegistry,
|
||||||
|
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
|
||||||
num_microbatches: Optional[lr.BatchSize] = None,
|
num_microbatches: Optional[lr.BatchSize] = None,
|
||||||
):
|
):
|
||||||
"""Computes the per-example loss gradient norms for given data.
|
"""Computes the per-example loss gradient norms for given data.
|
||||||
|
@ -86,6 +87,9 @@ def compute_gradient_norms(
|
||||||
compute gradient norms quickly. See
|
compute gradient norms quickly. See
|
||||||
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||||
more details.
|
more details.
|
||||||
|
per_example_loss_fn: If not None, used as the function to compute the
|
||||||
|
vectorized per example loss. Otherwise, we derive it from `input_model`'s
|
||||||
|
loss function.
|
||||||
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
|
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
|
||||||
microbatches. If not None, indicates that the loss is grouped into
|
microbatches. If not None, indicates that the loss is grouped into
|
||||||
num_microbatches (in this case, the batch dimension needs to be a multiple
|
num_microbatches (in this case, the batch dimension needs to be a multiple
|
||||||
|
@ -109,9 +113,10 @@ def compute_gradient_norms(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Ignore the original loss function's reduction to get per-example loss.
|
# Ignore the original loss function's reduction to get per-example loss.
|
||||||
loss_config = input_model.loss.get_config()
|
if per_example_loss_fn is None:
|
||||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
loss_config = input_model.loss.get_config()
|
||||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||||
|
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||||
losses = per_example_loss_fn(y_batch, model_outputs)
|
losses = per_example_loss_fn(y_batch, model_outputs)
|
||||||
if num_microbatches is not None:
|
if num_microbatches is not None:
|
||||||
losses = tf.reduce_mean(
|
losses = tf.reduce_mean(
|
||||||
|
@ -204,7 +209,11 @@ def compute_pred_and_clipped_gradients(
|
||||||
gradient of the loss function.
|
gradient of the loss function.
|
||||||
"""
|
"""
|
||||||
gradient_norms = compute_gradient_norms(
|
gradient_norms = compute_gradient_norms(
|
||||||
input_model, x_batch, y_batch, layer_registry, num_microbatches
|
input_model,
|
||||||
|
x_batch,
|
||||||
|
y_batch,
|
||||||
|
layer_registry,
|
||||||
|
num_microbatches=num_microbatches,
|
||||||
)
|
)
|
||||||
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
|
|
|
@ -71,16 +71,25 @@ def double_dense_layer_computation(
|
||||||
return [vars1, vars2], outputs, sqr_norm_fn
|
return [vars1, vars2], outputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
|
def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
|
||||||
|
x = tf.reshape(x, (tf.shape(x)[0], -1))
|
||||||
|
y = tf.reshape(y, (tf.shape(y)[0], -1))
|
||||||
|
# Define a loss function which is unlikely to be coincidently defined.
|
||||||
|
return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1)
|
||||||
|
|
||||||
|
|
||||||
def compute_true_gradient_norms(
|
def compute_true_gradient_norms(
|
||||||
input_model: tf.keras.Model,
|
input_model: tf.keras.Model,
|
||||||
x_batch: tf.Tensor,
|
x_batch: tf.Tensor,
|
||||||
y_batch: tf.Tensor,
|
y_batch: tf.Tensor,
|
||||||
|
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||||
num_microbatches: Optional[int],
|
num_microbatches: Optional[int],
|
||||||
) -> layer_registry.OutputTensor:
|
) -> layer_registry.OutputTensor:
|
||||||
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
||||||
loss_config = input_model.loss.get_config()
|
if per_example_loss_fn is None:
|
||||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
loss_config = input_model.loss.get_config()
|
||||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||||
|
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||||
with tf.GradientTape(persistent=True) as tape:
|
with tf.GradientTape(persistent=True) as tape:
|
||||||
y_pred = input_model(x_batch)
|
y_pred = input_model(x_batch)
|
||||||
loss = per_example_loss_fn(y_batch, y_pred)
|
loss = per_example_loss_fn(y_batch, y_pred)
|
||||||
|
@ -109,6 +118,7 @@ def get_computed_and_true_norms(
|
||||||
layer_generator: LayerGenerator,
|
layer_generator: LayerGenerator,
|
||||||
input_dims: Union[int, List[int]],
|
input_dims: Union[int, List[int]],
|
||||||
output_dim: int,
|
output_dim: int,
|
||||||
|
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||||
num_microbatches: Optional[int],
|
num_microbatches: Optional[int],
|
||||||
is_eager: bool,
|
is_eager: bool,
|
||||||
x_input: tf.Tensor,
|
x_input: tf.Tensor,
|
||||||
|
@ -130,6 +140,8 @@ def get_computed_and_true_norms(
|
||||||
`idim` and returns output tensors of dimension `odim`.
|
`idim` and returns output tensors of dimension `odim`.
|
||||||
input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
|
input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
|
||||||
output_dim: The output dimension of the test `tf.keras.Model` instance.
|
output_dim: The output dimension of the test `tf.keras.Model` instance.
|
||||||
|
per_example_loss_fn: If not None, used as vectorized per example loss
|
||||||
|
function.
|
||||||
num_microbatches: The number of microbatches. None or an integer.
|
num_microbatches: The number of microbatches. None or an integer.
|
||||||
is_eager: A `bool` that is `True` if the model should be run eagerly.
|
is_eager: A `bool` that is `True` if the model should be run eagerly.
|
||||||
x_input: `tf.Tensor` inputs to be tested.
|
x_input: `tf.Tensor` inputs to be tested.
|
||||||
|
@ -159,11 +171,12 @@ def get_computed_and_true_norms(
|
||||||
x_input,
|
x_input,
|
||||||
y_batch,
|
y_batch,
|
||||||
layer_registry=registry,
|
layer_registry=registry,
|
||||||
|
per_example_loss_fn=per_example_loss_fn,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
)
|
)
|
||||||
tf.keras.utils.set_random_seed(rng_seed)
|
tf.keras.utils.set_random_seed(rng_seed)
|
||||||
true_norms = compute_true_gradient_norms(
|
true_norms = compute_true_gradient_norms(
|
||||||
model, x_input, y_batch, num_microbatches
|
model, x_input, y_batch, per_example_loss_fn, num_microbatches
|
||||||
)
|
)
|
||||||
return (computed_norms, true_norms)
|
return (computed_norms, true_norms)
|
||||||
|
|
||||||
|
@ -348,6 +361,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
layer_name=list(get_dense_layer_generators().keys()),
|
layer_name=list(get_dense_layer_generators().keys()),
|
||||||
input_dim=[4],
|
input_dim=[4],
|
||||||
output_dim=[1, 2],
|
output_dim=[1, 2],
|
||||||
|
per_example_loss_fn=[None, test_loss_fn],
|
||||||
num_microbatches=[None, 1, 2],
|
num_microbatches=[None, 1, 2],
|
||||||
is_eager=[True, False],
|
is_eager=[True, False],
|
||||||
)
|
)
|
||||||
|
@ -357,6 +371,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
layer_name,
|
layer_name,
|
||||||
input_dim,
|
input_dim,
|
||||||
output_dim,
|
output_dim,
|
||||||
|
per_example_loss_fn,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
is_eager,
|
is_eager,
|
||||||
):
|
):
|
||||||
|
@ -379,6 +394,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
layer_generator,
|
layer_generator,
|
||||||
input_dim,
|
input_dim,
|
||||||
output_dim,
|
output_dim,
|
||||||
|
per_example_loss_fn,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
is_eager,
|
is_eager,
|
||||||
x_input,
|
x_input,
|
||||||
|
@ -419,11 +435,18 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
],
|
],
|
||||||
model_name=list(get_embedding_model_generators().keys()),
|
model_name=list(get_embedding_model_generators().keys()),
|
||||||
output_dim=[2],
|
output_dim=[2],
|
||||||
|
per_example_loss_fn=[None, test_loss_fn],
|
||||||
num_microbatches=[None, 1, 2],
|
num_microbatches=[None, 1, 2],
|
||||||
is_eager=[True],
|
is_eager=[True],
|
||||||
)
|
)
|
||||||
def test_gradient_norms_on_various_models(
|
def test_gradient_norms_on_various_models(
|
||||||
self, x_batch, model_name, output_dim, num_microbatches, is_eager
|
self,
|
||||||
|
x_batch,
|
||||||
|
model_name,
|
||||||
|
output_dim,
|
||||||
|
per_example_loss_fn,
|
||||||
|
num_microbatches,
|
||||||
|
is_eager,
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
num_microbatches is not None
|
num_microbatches is not None
|
||||||
|
@ -442,6 +465,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
layer_generator=None,
|
layer_generator=None,
|
||||||
input_dims=x_batch.shape[1:],
|
input_dims=x_batch.shape[1:],
|
||||||
output_dim=output_dim,
|
output_dim=output_dim,
|
||||||
|
per_example_loss_fn=per_example_loss_fn,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
is_eager=is_eager,
|
is_eager=is_eager,
|
||||||
x_input=x_batch,
|
x_input=x_batch,
|
||||||
|
@ -455,11 +479,17 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
input_dim=[1, 2],
|
input_dim=[1, 2],
|
||||||
output_dim=[1, 2],
|
output_dim=[1, 2],
|
||||||
|
per_example_loss_fn=[None, test_loss_fn],
|
||||||
num_microbatches=[None, 1, 2],
|
num_microbatches=[None, 1, 2],
|
||||||
is_eager=[True, False],
|
is_eager=[True, False],
|
||||||
)
|
)
|
||||||
def test_gradient_norms_on_various_models(
|
def test_gradient_norms_on_various_models(
|
||||||
self, input_dim, output_dim, num_microbatches, is_eager
|
self,
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
per_example_loss_fn,
|
||||||
|
num_microbatches,
|
||||||
|
is_eager,
|
||||||
):
|
):
|
||||||
registry = layer_registry.make_default_layer_registry()
|
registry = layer_registry.make_default_layer_registry()
|
||||||
registry.insert(DoubleDense, double_dense_layer_computation)
|
registry.insert(DoubleDense, double_dense_layer_computation)
|
||||||
|
@ -475,6 +505,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
layer_generator=lambda a, b: DoubleDense(b),
|
layer_generator=lambda a, b: DoubleDense(b),
|
||||||
input_dims=input_dim,
|
input_dims=input_dim,
|
||||||
output_dim=output_dim,
|
output_dim=output_dim,
|
||||||
|
per_example_loss_fn=per_example_loss_fn,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
is_eager=is_eager,
|
is_eager=is_eager,
|
||||||
x_input=x_batch,
|
x_input=x_batch,
|
||||||
|
|
Loading…
Reference in a new issue