Efficient DPSGD with support to microbatched losses.

PiperOrigin-RevId: 513886957
This commit is contained in:
A. Unique TensorFlower 2023-03-03 12:03:15 -08:00
parent cbf34f2b04
commit 8bfafdd74d
6 changed files with 252 additions and 104 deletions

View file

@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function). `compute_gradient_norms()` function).
""" """
from typing import Dict, Iterable, Text, Union from typing import Dict, Iterable, Optional, Text, Union
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@ -31,7 +31,9 @@ InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
def get_registry_generator_fn( def get_registry_generator_fn(
tape: tf.GradientTape, layer_registry: lr.LayerRegistry tape: tf.GradientTape,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
): ):
"""Creates the generator function for `compute_gradient_norms()`.""" """Creates the generator function for `compute_gradient_norms()`."""
if layer_registry is None: if layer_registry is None:
@ -50,7 +52,7 @@ def get_registry_generator_fn(
) )
registry_fn = layer_registry.lookup(layer_instance) registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, tape layer_instance, args, tape, num_microbatches
) )
return layer_outputs, (layer_vars, layer_sqr_norm_fn) return layer_outputs, (layer_vars, layer_sqr_norm_fn)
else: else:
@ -65,6 +67,7 @@ def compute_gradient_norms(
x_batch: InputTensor, x_batch: InputTensor,
y_batch: tf.Tensor, y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry, layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
): ):
"""Computes the per-example loss gradient norms for given data. """Computes the per-example loss gradient norms for given data.
@ -83,13 +86,21 @@ def compute_gradient_norms(
compute gradient norms quickly. See compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details. more details.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
of num_microbatches). When there is microbatches, we always assume the
loss is the mean over a microbatch. And the gradient norm is computed for
each microbatch.
Returns: Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
per-example loss function. per-example loss function.
""" """
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn(tape, layer_registry) registry_generator_fn = get_registry_generator_fn(
tape, layer_registry, num_microbatches
)
# First loop computes the model outputs, summed loss, and generator outputs. # First loop computes the model outputs, summed loss, and generator outputs.
with tape: with tape:
model_outputs, generator_outputs_list = ( model_outputs, generator_outputs_list = (
@ -102,6 +113,10 @@ def compute_gradient_norms(
loss_config['reduction'] = tf.keras.losses.Reduction.NONE loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config) per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs) losses = per_example_loss_fn(y_batch, model_outputs)
if num_microbatches is not None:
losses = tf.reduce_mean(
lr.add_microbatch_axis(losses, num_microbatches), axis=1
)
summed_loss = tf.reduce_sum(losses) summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating # Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops. # backprop ops.
@ -149,6 +164,7 @@ def compute_pred_and_clipped_gradients(
y_batch: tf.Tensor, y_batch: tf.Tensor,
l2_norm_clip: float, l2_norm_clip: float,
layer_registry: lr.LayerRegistry, layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
): ):
"""Computes the per-example predictions and per-example clipped loss gradient. """Computes the per-example predictions and per-example clipped loss gradient.
@ -177,6 +193,10 @@ def compute_pred_and_clipped_gradients(
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples). trainable weights (see `layer_registry_factories.py` for examples).
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
of num_microbatches).
Returns: Returns:
A `tuple` `(y_pred, grad)`. The first element is the prediction generated by A `tuple` `(y_pred, grad)`. The first element is the prediction generated by
@ -184,11 +204,21 @@ def compute_pred_and_clipped_gradients(
gradient of the loss function. gradient of the loss function.
""" """
gradient_norms = compute_gradient_norms( gradient_norms = compute_gradient_norms(
input_model, x_batch, y_batch, layer_registry input_model, x_batch, y_batch, layer_registry, num_microbatches
) )
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
y_pred = input_model(x_batch, training=True) y_pred = input_model(x_batch, training=True)
if num_microbatches is not None:
y_batch = lr.add_microbatch_axis(y_batch, num_microbatches)
y_pred = lr.add_microbatch_axis(y_pred, num_microbatches)
# Warning: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches
# as it is the assumption made when computing the gradient norm.
# It is indeed the case for multiple keras loss functions
# (e.g. mean_squared_error and binary_crossentropy). However it
# is not defined in the contract so may not hold, especially for
# custom losses.
loss_value = input_model.compute_loss( loss_value = input_model.compute_loss(
x_batch, y_batch, y_pred, loss_weights x_batch, y_batch, y_pred, loss_weights
) )

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
from typing import Callable, Any, List, Union from typing import Any, Callable, List, Optional, Union
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
@ -49,14 +49,17 @@ class DoubleDense(tf.keras.layers.Layer):
def double_dense_layer_computation( def double_dense_layer_computation(
layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape layer_instance: tf.keras.layers.Layer,
inputs: Any,
tape: tf.GradientTape,
num_microbatches: Optional[int],
): ):
"""Layer registry function for the custom `DoubleDense` layer class.""" """Layer registry function for the custom `DoubleDense` layer class."""
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation( vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
layer_instance.dense1, inputs, tape layer_instance.dense1, inputs, tape, num_microbatches
) )
vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation( vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation(
layer_instance.dense2, (outputs,), tape layer_instance.dense2, (outputs,), tape, num_microbatches
) )
def sqr_norm_fn(base_vars): def sqr_norm_fn(base_vars):
@ -68,7 +71,10 @@ def double_dense_layer_computation(
def compute_true_gradient_norms( def compute_true_gradient_norms(
input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor input_model: tf.keras.Model,
x_batch: tf.Tensor,
y_batch: tf.Tensor,
num_microbatches: Optional[int],
): ):
"""Computes the real gradient norms for an input `(model, x, y)`.""" """Computes the real gradient norms for an input `(model, x, y)`."""
loss_config = input_model.loss.get_config() loss_config = input_model.loss.get_config()
@ -77,13 +83,22 @@ def compute_true_gradient_norms(
with tf.GradientTape(persistent=True) as tape: with tf.GradientTape(persistent=True) as tape:
y_pred = input_model(x_batch) y_pred = input_model(x_batch)
loss = per_example_loss_fn(y_batch, y_pred) loss = per_example_loss_fn(y_batch, y_pred)
if num_microbatches is not None:
loss = tf.reduce_mean(
tf.reshape(
loss,
tf.concat([[num_microbatches, -1], tf.shape(loss)[1:]], axis=0),
),
axis=1,
)
if isinstance(loss, tf.RaggedTensor): if isinstance(loss, tf.RaggedTensor):
loss = loss.to_tensor() loss = loss.to_tensor()
sqr_norms = [] sqr_norms = []
for var in input_model.trainable_variables: for var in input_model.trainable_variables:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False) jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, len(jacobian.shape)) reduction_axes = tf.range(1, len(jacobian.shape))
sqr_norms.append(tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)) sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
sqr_norms.append(sqr_norm)
sqr_norm_tsr = tf.stack(sqr_norms, axis=1) sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
@ -93,6 +108,7 @@ def get_computed_and_true_norms(
layer_generator: LayerGenerator, layer_generator: LayerGenerator,
input_dims: Union[int, List[int]], input_dims: Union[int, List[int]],
output_dim: int, output_dim: int,
num_microbatches: Optional[int],
is_eager: bool, is_eager: bool,
x_input: tf.Tensor, x_input: tf.Tensor,
rng_seed: int = 777, rng_seed: int = 777,
@ -113,6 +129,7 @@ def get_computed_and_true_norms(
`idim` and returns output tensors of dimension `odim`. `idim` and returns output tensors of dimension `odim`.
input_dims: The input dimension(s) of the test `tf.keras.Model` instance. input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
output_dim: The output dimension of the test `tf.keras.Model` instance. output_dim: The output dimension of the test `tf.keras.Model` instance.
num_microbatches: The number of microbatches. None or an integer.
is_eager: A `bool` that is `True` if the model should be run eagerly. is_eager: A `bool` that is `True` if the model should be run eagerly.
x_input: `tf.Tensor` inputs to be tested. x_input: `tf.Tensor` inputs to be tested.
rng_seed: An `int` used to initialize model weights. rng_seed: An `int` used to initialize model weights.
@ -137,10 +154,16 @@ def get_computed_and_true_norms(
y_batch = tf.ones_like(y_pred) y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed) tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms( computed_norms = clip_grads.compute_gradient_norms(
model, x_input, y_batch, layer_registry=registry model,
x_input,
y_batch,
layer_registry=registry,
num_microbatches=num_microbatches,
) )
tf.keras.utils.set_random_seed(rng_seed) tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(model, x_input, y_batch) true_norms = compute_true_gradient_norms(
model, x_input, y_batch, num_microbatches
)
return (computed_norms, true_norms) return (computed_norms, true_norms)
@ -322,18 +345,30 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product( @parameterized.product(
model_name=list(get_dense_model_generators().keys()), model_name=list(get_dense_model_generators().keys()),
layer_name=list(get_dense_layer_generators().keys()), layer_name=list(get_dense_layer_generators().keys()),
input_dim=[1, 2], input_dim=[4],
output_dim=[1, 2], output_dim=[1, 2],
num_microbatches=[None, 1, 2],
is_eager=[True, False], is_eager=[True, False],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, model_name, layer_name, input_dim, output_dim, is_eager self,
model_name,
layer_name,
input_dim,
output_dim,
num_microbatches,
is_eager,
): ):
model_generator = get_dense_model_generators()[model_name] model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name] layer_generator = get_dense_layer_generators()[layer_name]
x_batches = get_nd_test_batches(input_dim) x_batches = get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry() default_registry = layer_registry.make_default_layer_registry()
for x_batch in x_batches: for x_batch in x_batches:
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
continue
if model_name == 'tower1': if model_name == 'tower1':
x_input = [x_batch, x_batch] x_input = [x_batch, x_batch]
else: else:
@ -343,6 +378,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_generator, layer_generator,
input_dim, input_dim,
output_dim, output_dim,
num_microbatches,
is_eager, is_eager,
x_input, x_input,
registry=default_registry, registry=default_registry,
@ -362,6 +398,10 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
tf.ragged.constant( tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32 [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
), ),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
dtype=tf.int32,
),
# 3D inputs. # 3D inputs.
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32), tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
tf.convert_to_tensor( tf.convert_to_tensor(
@ -371,14 +411,24 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]], [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
dtype=tf.int32, dtype=tf.int32,
), ),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
dtype=tf.int32,
),
], ],
model_name=list(get_embedding_model_generators().keys()), model_name=list(get_embedding_model_generators().keys()),
output_dim=[1, 2], output_dim=[2],
is_eager=[True, False], num_microbatches=[None, 1, 2],
is_eager=[True],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, x_batch, model_name, output_dim, is_eager self, x_batch, model_name, output_dim, num_microbatches, is_eager
): ):
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
return
valid_test_input = ( valid_test_input = (
not isinstance(x_batch, tf.RaggedTensor) not isinstance(x_batch, tf.RaggedTensor)
and model_name == 'weighted_bow1' and model_name == 'weighted_bow1'
@ -391,6 +441,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_generator=None, layer_generator=None,
input_dims=x_batch.shape[1:], input_dims=x_batch.shape[1:],
output_dim=output_dim, output_dim=output_dim,
num_microbatches=num_microbatches,
is_eager=is_eager, is_eager=is_eager,
x_input=x_batch, x_input=x_batch,
registry=default_registry, registry=default_registry,
@ -403,20 +454,27 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product( @parameterized.product(
input_dim=[1, 2], input_dim=[1, 2],
output_dim=[1, 2], output_dim=[1, 2],
num_microbatches=[None, 1, 2],
is_eager=[True, False], is_eager=[True, False],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, input_dim, output_dim, is_eager self, input_dim, output_dim, num_microbatches, is_eager
): ):
registry = layer_registry.make_default_layer_registry() registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation) registry.insert(DoubleDense, double_dense_layer_computation)
x_batches = get_nd_test_batches(input_dim) x_batches = get_nd_test_batches(input_dim)
for x_batch in x_batches: for x_batch in x_batches:
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
continue
(computed_norms, true_norms) = get_computed_and_true_norms( (computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=make_two_layer_sequential_model, model_generator=make_two_layer_sequential_model,
layer_generator=lambda a, b: DoubleDense(b), layer_generator=lambda a, b: DoubleDense(b),
input_dims=input_dim, input_dims=input_dim,
output_dim=output_dim, output_dim=output_dim,
num_microbatches=num_microbatches,
is_eager=is_eager, is_eager=is_eager,
x_input=x_batch, x_input=x_batch,
registry=registry, registry=registry,

View file

@ -157,8 +157,8 @@ def all_trainable_layers_are_registered(
def add_aggregate_noise( def add_aggregate_noise(
input_model: tf.keras.Model, input_model: tf.keras.Model,
x_batch: InputTensor, clipped_grads: list[tf.Tensor],
clipped_grads: List[tf.Tensor], batch_size: tf.Tensor,
l2_norm_clip: float, l2_norm_clip: float,
noise_multiplier: float, noise_multiplier: float,
) -> List[tf.Tensor]: ) -> List[tf.Tensor]:
@ -169,8 +169,9 @@ def add_aggregate_noise(
Args: Args:
input_model: The `tf.keras.Model` to obtain the layers from. input_model: The `tf.keras.Model` to obtain the layers from.
x_batch: An `InputTensor` to be fed into the input layer of the model.
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size, used for normalizing the noise, when the loss
reduction is AUTO or SUM_OVER_BATCH_SIZE.
l2_norm_clip: Clipping norm (max L2 norm of each gradient). l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm. noise_multiplier: Ratio of the standard deviation to the clipping norm.
@ -186,17 +187,7 @@ def add_aggregate_noise(
]: ]:
if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO: if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.') logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
if isinstance(x_batch, tf.Tensor): scale /= tf.cast(batch_size, tf.float32)
scale /= tf.cast(tf.shape(x_batch)[0], tf.float32)
elif isinstance(x_batch, dict):
batch_sizes = [
tf.cast(tf.shape(v)[0], tf.float32) for v in x_batch.values()
]
scale /= tf.math.reduce_min(batch_sizes)
else:
raise NotImplementedError(
'Unknown container/class %s for input' % x_batch.__class__.__name__
)
def add_noise(g): def add_noise(g):
return g + tf.random.normal( return g + tf.random.normal(

View file

@ -38,9 +38,18 @@ whose i-th entry is the L2 norm of the i-th input vector, then
where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`. where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799 Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
"""
from typing import Any, Callable, Dict, Iterable, Text, Tuple, Type, Union We also extend fast gradient norm computation to the case when the losses
are microbatched, i.e. each per example loss is the mean of a set of losses.
This could be useful for achieving user-level privacy and for improving the
quality of DP models, through better estimation of the gradients due to
aggregation at the microbatch level.
"""
# copybara.strip_begin
# The detailed algorithm can be found in go/fast-dpsgd-mb.
# copybara.strip_end
from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Type, Union
import tensorflow as tf import tensorflow as tf
@ -56,6 +65,7 @@ RegistryFunction = Callable[
] ]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
BatchSize = Union[int, tf.Tensor]
# ============================================================================== # ==============================================================================
@ -88,6 +98,37 @@ class LayerRegistry:
self._registry[layer_key] = layer_registry_function self._registry[layer_key] = layer_registry_function
# ==============================================================================
# Utilities
# ==============================================================================
def add_microbatch_axis(
x: tf.Tensor,
num_microbatches: Optional[BatchSize],
) -> tf.Tensor:
"""Adds the microbatch axis.
Reshape the input tensor to replace the first(batch) dimension with the
shape [num_microbatches, batch_size / num_microbatches]. The batch size
must be a multiple of num_microbatches (unless it is None, meaning
num_microbatches is the same as the batch size).
Args:
x: the input tensor.
num_microbatches: None or a numeric value or a scalar `tf.Tensor`.
Returns:
The reshaped input tensor.
"""
if num_microbatches is None:
return tf.expand_dims(x, 1)
with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
):
return tf.reshape(
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
)
# ============================================================================== # ==============================================================================
# Supported Keras layers # Supported Keras layers
# ============================================================================== # ==============================================================================
@ -95,6 +136,7 @@ def dense_layer_computation(
layer_instance: tf.keras.layers.Dense, layer_instance: tf.keras.layers.Dense,
inputs: Tuple[InputTensor], inputs: Tuple[InputTensor],
tape: tf.GradientTape, tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput: ) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`. """Registry function for `tf.keras.layers.Dense`.
@ -111,6 +153,9 @@ def dense_layer_computation(
output. output.
tape: A `tf.GradientTape` instance that will be used to watch the output tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`. `base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
indicating whether and how the losses are grouped into microbatches. If
not None, num_microbatches must divide the batch size.
Returns: Returns:
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
@ -132,21 +177,29 @@ def dense_layer_computation(
tape.watch(base_vars) tape.watch(base_vars)
layer_instance.activation = orig_activation layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars outputs = orig_activation(base_vars) if orig_activation else base_vars
def sqr_norm_fn(base_vars_grads): def sqr_norm_fn(base_vars_grads):
sqr_inputs = tf.square(*inputs)
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs)) def _compute_gramian(x):
input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes) if num_microbatches is not None:
x_microbatched = add_microbatch_axis(x, num_microbatches)
return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
else:
# Special handling for better efficiency
return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
inputs_gram = _compute_gramian(*inputs)
base_vars_grads_gram = _compute_gramian(base_vars_grads)
if layer_instance.use_bias: if layer_instance.use_bias:
# Adding a bias term is equivalent to a layer with no bias term and which # Adding a bias term is equivalent to a layer with no bias term and which
# adds an additional variable to the layer input that only takes a # adds an additional variable to the layer input that only takes a
# constant value of 1.0. This is thus equivalent to adding 1.0 to the sum # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
# of the squared values of the inputs. # of the squared values of the inputs.
input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype) inputs_gram += 1.0
reduction_axes = tf.range(1, tf.rank(base_vars_grads)) return tf.reduce_sum(
base_vars_sqr_norms = tf.reduce_sum( inputs_gram * base_vars_grads_gram,
tf.square(base_vars_grads), axis=reduction_axes axis=tf.range(1, tf.rank(inputs_gram)),
) )
return input_sqr_norms * base_vars_sqr_norms
return base_vars, outputs, sqr_norm_fn return base_vars, outputs, sqr_norm_fn
@ -155,6 +208,7 @@ def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding, layer_instance: tf.keras.layers.Embedding,
inputs: Tuple[InputTensor], inputs: Tuple[InputTensor],
tape: tf.GradientTape, tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput: ) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`. """Registry function for `tf.keras.layers.Embedding`.
@ -171,6 +225,9 @@ def embedding_layer_computation(
output. output.
tape: A `tf.GradientTape` instance that will be used to watch the output tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`. `base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
indicating whether and how the losses are grouped into microbatches. If
not None, num_microbatches must divide the batch size.
Returns: Returns:
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
@ -219,6 +276,13 @@ def embedding_layer_computation(
raise NotImplementedError( raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__ "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
) )
row_indices = tf.cast(row_indices, tf.int32)
if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int32)
nrows = num_microbatches
row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int32
)
# Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()` # Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()`
# call. The sum is reduced by the repeated embedding indices and batch # call. The sum is reduced by the repeated embedding indices and batch
# index. It is adapted from the logic in: # index. It is adapted from the logic in:

View file

@ -39,11 +39,22 @@ def make_dp_model_class(cls):
`(i.e., tf.shape(loss)[1] == 1)` and whose i-th row corresponds to `(i.e., tf.shape(loss)[1] == 1)` and whose i-th row corresponds to
the loss of the i-th example. the loss of the i-th example.
This clipping algorithm specifically computes clipped gradients at the This clipping algorithm specifically computes clipped gradients at the
per-example level using the layer registry functions in `layer_registry` per-example or per microbatch (when `num_microbatches` is not None)
(see clip_grads.py for more information about the algorithm). In this level using the layer registry functions in `layer_registry` (see
setting, microbatching is not used (it is equivalent to clip_grads.py for more information about the algorithm).
`num_microbatches == batch_size`), and the input `num_microbatches`
is ignored. WARNING: with faster gradient clipping, and when num_microbatches is not
None, the per microbatch loss is assumed to be computed as the mean
of the loss over the microbatch, or effectively, by reshaping the loss
from the shape [batch_size, ...] to the shape
[num_microbatches, batch_size/num_microbatches, ...] and computing the
mean of the loss over the microbatches. This would require that the loss
function behaves accordingly. This is true for multiple common
predefined keras loss functions (e.g. mean_squared_loss,
binary_crossentropy) but may not hold for custom losses (and how such
aggregation is done is not exposed by the loss function, unfortunately).
It is the caller's responsibility to make sure that the loss function
does behave this way.
When instantiating this class, you need to supply several When instantiating this class, you need to supply several
DP-related arguments followed by the standard arguments for DP-related arguments followed by the standard arguments for
@ -115,6 +126,7 @@ def make_dp_model_class(cls):
if isinstance(num_microbatches, bool): if isinstance(num_microbatches, bool):
raise ValueError('Boolean value supplied for `num_microbatches`. ' raise ValueError('Boolean value supplied for `num_microbatches`. '
'Did you intend it for `use_xla`?') 'Did you intend it for `use_xla`?')
self._num_microbatches = num_microbatches
# If all the trainable layers are in the input layer registry, we # If all the trainable layers are in the input layer registry, we
# don't need to use microbatching and can instead use the "fast" # don't need to use microbatching and can instead use the "fast"
@ -126,16 +138,8 @@ def make_dp_model_class(cls):
) )
and gradient_clipping_utils.has_internal_compute_graph(self) and gradient_clipping_utils.has_internal_compute_graph(self)
): ):
if num_microbatches is not None:
raise ValueError(
'Cannot initialize a model where num_microbatches '
'is not `None` and all trainable layers are '
'registered in layer_registry.'
)
self._num_microbatches = None
self._enable_fast_peg_computation = True self._enable_fast_peg_computation = True
else: else:
self._num_microbatches = num_microbatches
self._enable_fast_peg_computation = False self._enable_fast_peg_computation = False
if use_xla: if use_xla:
@ -198,10 +202,20 @@ def make_dp_model_class(cls):
# trick, and uses these norms to clip the per-example gradients. # trick, and uses these norms to clip the per-example gradients.
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data) x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients( y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
self, x, y, self._l2_norm_clip, self._layer_registry self,
x,
y,
self._l2_norm_clip,
self._layer_registry,
self._num_microbatches,
) )
batch_size = self._num_microbatches or tf.shape(y)[0]
grads = gradient_clipping_utils.add_aggregate_noise( grads = gradient_clipping_utils.add_aggregate_noise(
self, x, clipped_grads, self._l2_norm_clip, self._noise_multiplier self,
clipped_grads,
batch_size,
self._l2_norm_clip,
self._noise_multiplier,
) )
else: else:
logging.info('Computing gradients using microbatching.') logging.info('Computing gradients using microbatching.')

View file

@ -139,9 +139,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]]) train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
learning_rate = 1.0 learning_rate = 1.0
for test_reg, test_nm in zip( for test_reg in get_layer_registries():
get_layer_registries(), [num_microbatches, None]
):
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError() loss = tf.keras.losses.MeanSquaredError()
@ -149,7 +147,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model = dp_keras_model.DPSequential( model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip, l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=test_nm, num_microbatches=num_microbatches,
layer_registry=test_reg, layer_registry=test_reg,
layers=[ layers=[
tf.keras.layers.InputLayer(input_shape=(2,)), tf.keras.layers.InputLayer(input_shape=(2,)),
@ -173,10 +171,11 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
) )
expected_weights = np.squeeze(-learning_rate * expected_grads) expected_weights = np.squeeze(-learning_rate * expected_grads)
self.assertAllClose(model_weights, expected_weights) self.assertAllClose(model_weights, expected_weights)
@parameterized.named_parameters( @parameterized.named_parameters(
('noise_multiplier 3 2 None', 3.0, 2.0, None),
('noise_multiplier 5 4 None', 5.0, 4.0, None),
('noise_multiplier 3 2 1', 3.0, 2.0, 1), ('noise_multiplier 3 2 1', 3.0, 2.0, 1),
('noise_multiplier 5 4 1', 5.0, 4.0, 1), ('noise_multiplier 5 4 1', 5.0, 4.0, 1),
('noise_multiplier 3 2 2', 3.0, 2.0, 2), ('noise_multiplier 3 2 2', 3.0, 2.0, 2),
@ -198,9 +197,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
learning_rate = 1.0 learning_rate = 1.0
for test_reg, test_nm in zip( for test_reg in get_layer_registries():
get_layer_registries(), [num_microbatches, None]
):
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError() loss = tf.keras.losses.MeanSquaredError()
@ -208,7 +205,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model = dp_keras_model.DPSequential( model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip, l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier, noise_multiplier=noise_multiplier,
num_microbatches=test_nm, num_microbatches=num_microbatches,
layer_registry=test_reg, layer_registry=test_reg,
layers=[ layers=[
tf.keras.layers.InputLayer(input_shape=(1000,)), tf.keras.layers.InputLayer(input_shape=(1000,)),
@ -220,11 +217,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model.compile(optimizer=optimizer, loss=loss) model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4) model.fit(train_data, train_labels, epochs=1, batch_size=4)
effective_num_microbatches = ( effective_num_microbatches = num_microbatches or train_data.shape[0]
train_data.shape[0]
if model._num_microbatches is None
else num_microbatches
)
model_weights = model.get_weights() model_weights = model.get_weights()
measured_std = np.std(model_weights[0]) measured_std = np.std(model_weights[0])
@ -248,16 +241,14 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
train_labels = np.array([[0], [1], [1], [0]]) train_labels = np.array([[0], [1], [1], [0]])
learning_rate = 1.0 learning_rate = 1.0
for test_reg, test_nm in zip( for test_reg in get_layer_registries():
get_layer_registries(), [num_microbatches, None]
):
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = dp_keras_model.DPSequential( model = dp_keras_model.DPSequential(
l2_norm_clip=1.0e9, l2_norm_clip=1.0e9,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=test_nm, num_microbatches=num_microbatches,
layer_registry=test_reg, layer_registry=test_reg,
layers=[ layers=[
tf.keras.layers.InputLayer(input_shape=(2,)), tf.keras.layers.InputLayer(input_shape=(2,)),