From 8bfafdd74d42cfd0426fed2efdc63af78cbfa468 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 3 Mar 2023 12:03:15 -0800 Subject: [PATCH] Efficient DPSGD with support to microbatched losses. PiperOrigin-RevId: 513886957 --- .../fast_gradient_clipping/clip_grads.py | 42 ++++++-- .../fast_gradient_clipping/clip_grads_test.py | 86 +++++++++++++--- .../gradient_clipping_utils.py | 19 +--- .../fast_gradient_clipping/layer_registry.py | 84 ++++++++++++++-- .../privacy/keras_models/dp_keras_model.py | 98 +++++++++++-------- .../keras_models/dp_keras_model_test.py | 27 ++--- 6 files changed, 252 insertions(+), 104 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 4af6695..32880af 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the `compute_gradient_norms()` function). """ -from typing import Dict, Iterable, Text, Union +from typing import Dict, Iterable, Optional, Text, Union import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils @@ -31,7 +31,9 @@ InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] def get_registry_generator_fn( - tape: tf.GradientTape, layer_registry: lr.LayerRegistry + tape: tf.GradientTape, + layer_registry: lr.LayerRegistry, + num_microbatches: Optional[lr.BatchSize] = None, ): """Creates the generator function for `compute_gradient_norms()`.""" if layer_registry is None: @@ -50,14 +52,14 @@ def get_registry_generator_fn( ) registry_fn = layer_registry.lookup(layer_instance) (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( - layer_instance, args, tape + layer_instance, args, tape, num_microbatches ) return layer_outputs, (layer_vars, layer_sqr_norm_fn) else: # Non-trainable layer. return layer_instance(*args, **kwargs), None - return registry_generator_fn + return registry_generator_fn def compute_gradient_norms( @@ -65,6 +67,7 @@ def compute_gradient_norms( x_batch: InputTensor, y_batch: tf.Tensor, layer_registry: lr.LayerRegistry, + num_microbatches: Optional[lr.BatchSize] = None, ): """Computes the per-example loss gradient norms for given data. @@ -83,13 +86,21 @@ def compute_gradient_norms( compute gradient norms quickly. See `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for more details. + num_microbatches: An optional number or scalar `tf.Tensor` for the number of + microbatches. If not None, indicates that the loss is grouped into + num_microbatches (in this case, the batch dimension needs to be a multiple + of num_microbatches). When there is microbatches, we always assume the + loss is the mean over a microbatch. And the gradient norm is computed for + each microbatch. Returns: A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th per-example loss function. """ tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) - registry_generator_fn = get_registry_generator_fn(tape, layer_registry) + registry_generator_fn = get_registry_generator_fn( + tape, layer_registry, num_microbatches + ) # First loop computes the model outputs, summed loss, and generator outputs. 
with tape: model_outputs, generator_outputs_list = ( @@ -102,6 +113,10 @@ def compute_gradient_norms( loss_config['reduction'] = tf.keras.losses.Reduction.NONE per_example_loss_fn = input_model.loss.from_config(loss_config) losses = per_example_loss_fn(y_batch, model_outputs) + if num_microbatches is not None: + losses = tf.reduce_mean( + lr.add_microbatch_axis(losses, num_microbatches), axis=1 + ) summed_loss = tf.reduce_sum(losses) # Unwrap the generator outputs so that the next loop avoids duplicating # backprop ops. @@ -149,6 +164,7 @@ def compute_pred_and_clipped_gradients( y_batch: tf.Tensor, l2_norm_clip: float, layer_registry: lr.LayerRegistry, + num_microbatches: Optional[lr.BatchSize] = None, ): """Computes the per-example predictions and per-example clipped loss gradient. @@ -177,6 +193,10 @@ def compute_pred_and_clipped_gradients( `output` is the pre-activator tensor, `sqr_grad_norms` is related to the squared norms of a layer's pre-activation tensor, and `vars` are relevant trainable weights (see `layer_registry_factories.py` for examples). + num_microbatches: An optional number or scalar `tf.Tensor` for the number of + microbatches. If not None, indicates that the loss is grouped into + num_microbatches (in this case, the batch dimension needs to be a multiple + of num_microbatches). Returns: A `tuple` `(y_pred, grad)`. The first element is the prediction generated by @@ -184,11 +204,21 @@ def compute_pred_and_clipped_gradients( gradient of the loss function. """ gradient_norms = compute_gradient_norms( - input_model, x_batch, y_batch, layer_registry + input_model, x_batch, y_batch, layer_registry, num_microbatches ) loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) with tf.GradientTape() as tape: y_pred = input_model(x_batch, training=True) + if num_microbatches is not None: + y_batch = lr.add_microbatch_axis(y_batch, num_microbatches) + y_pred = lr.add_microbatch_axis(y_pred, num_microbatches) + # Warning: When num_microbatches is not None, we need to be sure that + # `compute_loss` always computes the mean over the microbatches + # as it is the assumption made when computing the gradient norm. + # It is indeed the case for multiple keras loss functions + # (e.g. mean_squared_error and binary_crossentropy). However it + # is not defined in the contract so may not hold, especially for + # custom losses. loss_value = input_model.compute_loss( x_batch, y_batch, y_pred, loss_weights ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 1933e21..5275b2a 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
import itertools -from typing import Callable, Any, List, Union +from typing import Any, Callable, List, Optional, Union from absl.testing import parameterized import tensorflow as tf @@ -49,14 +49,17 @@ class DoubleDense(tf.keras.layers.Layer): def double_dense_layer_computation( - layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape + layer_instance: tf.keras.layers.Layer, + inputs: Any, + tape: tf.GradientTape, + num_microbatches: Optional[int], ): """Layer registry function for the custom `DoubleDense` layer class.""" vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation( - layer_instance.dense1, inputs, tape + layer_instance.dense1, inputs, tape, num_microbatches ) vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation( - layer_instance.dense2, (outputs,), tape + layer_instance.dense2, (outputs,), tape, num_microbatches ) def sqr_norm_fn(base_vars): @@ -68,7 +71,10 @@ def double_dense_layer_computation( def compute_true_gradient_norms( - input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor + input_model: tf.keras.Model, + x_batch: tf.Tensor, + y_batch: tf.Tensor, + num_microbatches: Optional[int], ): """Computes the real gradient norms for an input `(model, x, y)`.""" loss_config = input_model.loss.get_config() @@ -77,13 +83,22 @@ def compute_true_gradient_norms( with tf.GradientTape(persistent=True) as tape: y_pred = input_model(x_batch) loss = per_example_loss_fn(y_batch, y_pred) + if num_microbatches is not None: + loss = tf.reduce_mean( + tf.reshape( + loss, + tf.concat([[num_microbatches, -1], tf.shape(loss)[1:]], axis=0), + ), + axis=1, + ) if isinstance(loss, tf.RaggedTensor): loss = loss.to_tensor() sqr_norms = [] for var in input_model.trainable_variables: jacobian = tape.jacobian(loss, var, experimental_use_pfor=False) reduction_axes = tf.range(1, len(jacobian.shape)) - sqr_norms.append(tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)) + sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes) + sqr_norms.append(sqr_norm) sqr_norm_tsr = tf.stack(sqr_norms, axis=1) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) @@ -93,6 +108,7 @@ def get_computed_and_true_norms( layer_generator: LayerGenerator, input_dims: Union[int, List[int]], output_dim: int, + num_microbatches: Optional[int], is_eager: bool, x_input: tf.Tensor, rng_seed: int = 777, @@ -113,6 +129,7 @@ def get_computed_and_true_norms( `idim` and returns output tensors of dimension `odim`. input_dims: The input dimension(s) of the test `tf.keras.Model` instance. output_dim: The output dimension of the test `tf.keras.Model` instance. + num_microbatches: The number of microbatches. None or an integer. is_eager: A `bool` that is `True` if the model should be run eagerly. x_input: `tf.Tensor` inputs to be tested. rng_seed: An `int` used to initialize model weights. 
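For reference, the microbatch handling added above in `compute_gradient_norms` and in the test helper `compute_true_gradient_norms` reduces the per-example losses to per-microbatch losses by reshaping the batch dimension into `[num_microbatches, batch_size / num_microbatches]` and averaging over the second axis. A minimal standalone sketch of that reduction (the values and shapes are illustrative, not taken from the tests):

```python
import tensorflow as tf

# Illustrative values only; assumes batch_size is a multiple of num_microbatches.
per_example_losses = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # batch_size = 6
num_microbatches = 3

# Group the batch dimension into [num_microbatches, batch_size / num_microbatches].
microbatched = tf.reshape(
    per_example_losses,
    tf.concat(
        [[num_microbatches, -1], tf.shape(per_example_losses)[1:]], axis=0
    ),
)
# Each microbatch loss is the mean over its group; gradient norms (and hence
# clipping) are then computed per microbatch rather than per example.
microbatch_losses = tf.reduce_mean(microbatched, axis=1)  # shape [3]
```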
@@ -137,10 +154,16 @@ def get_computed_and_true_norms( y_batch = tf.ones_like(y_pred) tf.keras.utils.set_random_seed(rng_seed) computed_norms = clip_grads.compute_gradient_norms( - model, x_input, y_batch, layer_registry=registry + model, + x_input, + y_batch, + layer_registry=registry, + num_microbatches=num_microbatches, ) tf.keras.utils.set_random_seed(rng_seed) - true_norms = compute_true_gradient_norms(model, x_input, y_batch) + true_norms = compute_true_gradient_norms( + model, x_input, y_batch, num_microbatches + ) return (computed_norms, true_norms) @@ -322,18 +345,30 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): @parameterized.product( model_name=list(get_dense_model_generators().keys()), layer_name=list(get_dense_layer_generators().keys()), - input_dim=[1, 2], + input_dim=[4], output_dim=[1, 2], + num_microbatches=[None, 1, 2], is_eager=[True, False], ) def test_gradient_norms_on_various_models( - self, model_name, layer_name, input_dim, output_dim, is_eager + self, + model_name, + layer_name, + input_dim, + output_dim, + num_microbatches, + is_eager, ): model_generator = get_dense_model_generators()[model_name] layer_generator = get_dense_layer_generators()[layer_name] x_batches = get_nd_test_batches(input_dim) default_registry = layer_registry.make_default_layer_registry() for x_batch in x_batches: + if ( + num_microbatches is not None + and x_batch.shape[0] % num_microbatches != 0 + ): + continue if model_name == 'tower1': x_input = [x_batch, x_batch] else: @@ -343,6 +378,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): layer_generator, input_dim, output_dim, + num_microbatches, is_eager, x_input, registry=default_registry, @@ -362,6 +398,10 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): tf.ragged.constant( [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32 ), + tf.ragged.constant( + [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]], + dtype=tf.int32, + ), # 3D inputs. 
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32), tf.convert_to_tensor( @@ -371,14 +411,24 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]], dtype=tf.int32, ), + tf.ragged.constant( + [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]], + dtype=tf.int32, + ), ], model_name=list(get_embedding_model_generators().keys()), - output_dim=[1, 2], - is_eager=[True, False], + output_dim=[2], + num_microbatches=[None, 1, 2], + is_eager=[True], ) def test_gradient_norms_on_various_models( - self, x_batch, model_name, output_dim, is_eager + self, x_batch, model_name, output_dim, num_microbatches, is_eager ): + if ( + num_microbatches is not None + and x_batch.shape[0] % num_microbatches != 0 + ): + return valid_test_input = ( not isinstance(x_batch, tf.RaggedTensor) and model_name == 'weighted_bow1' @@ -391,6 +441,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): layer_generator=None, input_dims=x_batch.shape[1:], output_dim=output_dim, + num_microbatches=num_microbatches, is_eager=is_eager, x_input=x_batch, registry=default_registry, @@ -403,20 +454,27 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): @parameterized.product( input_dim=[1, 2], output_dim=[1, 2], + num_microbatches=[None, 1, 2], is_eager=[True, False], ) def test_gradient_norms_on_various_models( - self, input_dim, output_dim, is_eager + self, input_dim, output_dim, num_microbatches, is_eager ): registry = layer_registry.make_default_layer_registry() registry.insert(DoubleDense, double_dense_layer_computation) x_batches = get_nd_test_batches(input_dim) for x_batch in x_batches: + if ( + num_microbatches is not None + and x_batch.shape[0] % num_microbatches != 0 + ): + continue (computed_norms, true_norms) = get_computed_and_true_norms( model_generator=make_two_layer_sequential_model, layer_generator=lambda a, b: DoubleDense(b), input_dims=input_dim, output_dim=output_dim, + num_microbatches=num_microbatches, is_eager=is_eager, x_input=x_batch, registry=registry, diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 428dc0f..ec9d996 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -157,8 +157,8 @@ def all_trainable_layers_are_registered( def add_aggregate_noise( input_model: tf.keras.Model, - x_batch: InputTensor, - clipped_grads: List[tf.Tensor], + clipped_grads: list[tf.Tensor], + batch_size: tf.Tensor, l2_norm_clip: float, noise_multiplier: float, ) -> List[tf.Tensor]: @@ -169,8 +169,9 @@ def add_aggregate_noise( Args: input_model: The `tf.keras.Model` to obtain the layers from. - x_batch: An `InputTensor` to be fed into the input layer of the model. clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. + batch_size: The batch size, used for normalizing the noise, when the loss + reduction is AUTO or SUM_OVER_BATCH_SIZE. l2_norm_clip: Clipping norm (max L2 norm of each gradient). noise_multiplier: Ratio of the standard deviation to the clipping norm. 
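With the refactored signature above, `add_aggregate_noise` no longer infers the batch size from `x_batch`; the caller passes `batch_size` directly (the number of microbatches when microbatching is enabled, as in the `dp_keras_model.py` change further below). A rough sketch of the resulting noise scale under a mean-style loss reduction, assuming the standard DP-SGD Gaussian mechanism described in the docstring (illustrative names, not the library API):

```python
import tensorflow as tf

# Illustrative sketch of the noise normalization; not the library code.
l2_norm_clip = 1.0
noise_multiplier = 0.5
batch_size = 32  # or num_microbatches when microbatching is enabled

# noise_multiplier is the ratio of the noise stddev to the clipping norm; for
# AUTO / SUM_OVER_BATCH_SIZE reductions the scale is divided by batch_size so
# the noise matches the mean-reduced sum of clipped gradients.
noise_stddev = l2_norm_clip * noise_multiplier / batch_size

def add_noise(g: tf.Tensor) -> tf.Tensor:
  return g + tf.random.normal(tf.shape(g), mean=0.0, stddev=noise_stddev)
```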
@@ -186,17 +187,7 @@ def add_aggregate_noise( ]: if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO: logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.') - if isinstance(x_batch, tf.Tensor): - scale /= tf.cast(tf.shape(x_batch)[0], tf.float32) - elif isinstance(x_batch, dict): - batch_sizes = [ - tf.cast(tf.shape(v)[0], tf.float32) for v in x_batch.values() - ] - scale /= tf.math.reduce_min(batch_sizes) - else: - raise NotImplementedError( - 'Unknown container/class %s for input' % x_batch.__class__.__name__ - ) + scale /= tf.cast(batch_size, tf.float32) def add_noise(g): return g + tf.random.normal( diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index 556fcc4..eaa188d 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -38,9 +38,18 @@ whose i-th entry is the L2 norm of the i-th input vector, then where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`. Details of this decomposition can be found in https://arxiv.org/abs/1510.01799 -""" -from typing import Any, Callable, Dict, Iterable, Text, Tuple, Type, Union +We also extend fast gradient norm computation to the case when the losses +are microbatched, i.e. each per example loss is the mean of a set of losses. +This could be useful for achieving user-level privacy and for improving the +quality of DP models, through better estimation of the gradients due to +aggregation at the microbatch level. +""" +# copybara.strip_begin +# The detailed algorithm can be found in go/fast-dpsgd-mb. +# copybara.strip_end + +from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Type, Union import tensorflow as tf @@ -56,6 +65,7 @@ RegistryFunction = Callable[ ] InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] +BatchSize = Union[int, tf.Tensor] # ============================================================================== @@ -88,6 +98,37 @@ class LayerRegistry: self._registry[layer_key] = layer_registry_function +# ============================================================================== +# Utilities +# ============================================================================== +def add_microbatch_axis( + x: tf.Tensor, + num_microbatches: Optional[BatchSize], +) -> tf.Tensor: + """Adds the microbatch axis. + + Reshape the input tensor to replace the first(batch) dimension with the + shape [num_microbatches, batch_size / num_microbatches]. The batch size + must be a multiple of num_microbatches (unless it is None, meaning + num_microbatches is the same as the batch size). + + Args: + x: the input tensor. + num_microbatches: None or a numeric value or a scalar `tf.Tensor`. + + Returns: + The reshaped input tensor. 
+ """ + if num_microbatches is None: + return tf.expand_dims(x, 1) + with tf.control_dependencies( + [tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)] + ): + return tf.reshape( + x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0) + ) + + # ============================================================================== # Supported Keras layers # ============================================================================== @@ -95,6 +136,7 @@ def dense_layer_computation( layer_instance: tf.keras.layers.Dense, inputs: Tuple[InputTensor], tape: tf.GradientTape, + num_microbatches: Optional[tf.Tensor] = None, ) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Dense`. @@ -111,6 +153,9 @@ def dense_layer_computation( output. tape: A `tf.GradientTape` instance that will be used to watch the output `base_vars`. + num_microbatches: An optional numeric value or scalar `tf.Tensor` for + indicating whether and how the losses are grouped into microbatches. If + not None, num_microbatches must divide the batch size. Returns: A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the @@ -132,21 +177,29 @@ def dense_layer_computation( tape.watch(base_vars) layer_instance.activation = orig_activation outputs = orig_activation(base_vars) if orig_activation else base_vars + def sqr_norm_fn(base_vars_grads): - sqr_inputs = tf.square(*inputs) - inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs)) - input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes) + + def _compute_gramian(x): + if num_microbatches is not None: + x_microbatched = add_microbatch_axis(x, num_microbatches) + return tf.matmul(x_microbatched, x_microbatched, transpose_b=True) + else: + # Special handling for better efficiency + return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x))) + + inputs_gram = _compute_gramian(*inputs) + base_vars_grads_gram = _compute_gramian(base_vars_grads) if layer_instance.use_bias: # Adding a bias term is equivalent to a layer with no bias term and which # adds an additional variable to the layer input that only takes a # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum # of the squared values of the inputs. - input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype) - reduction_axes = tf.range(1, tf.rank(base_vars_grads)) - base_vars_sqr_norms = tf.reduce_sum( - tf.square(base_vars_grads), axis=reduction_axes + inputs_gram += 1.0 + return tf.reduce_sum( + inputs_gram * base_vars_grads_gram, + axis=tf.range(1, tf.rank(inputs_gram)), ) - return input_sqr_norms * base_vars_sqr_norms return base_vars, outputs, sqr_norm_fn @@ -155,6 +208,7 @@ def embedding_layer_computation( layer_instance: tf.keras.layers.Embedding, inputs: Tuple[InputTensor], tape: tf.GradientTape, + num_microbatches: Optional[tf.Tensor] = None, ) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Embedding`. @@ -171,6 +225,9 @@ def embedding_layer_computation( output. tape: A `tf.GradientTape` instance that will be used to watch the output `base_vars`. + num_microbatches: An optional numeric value or scalar `tf.Tensor` for + indicating whether and how the losses are grouped into microbatches. If + not None, num_microbatches must divide the batch size. Returns: A `tuple` `(base_vars, outputs, sqr_norm_fn)`. 
`base_vars` is the @@ -219,6 +276,13 @@ def embedding_layer_computation( raise NotImplementedError( "Cannot parse input_ids of type %s" % input_ids.__class__.__name__ ) + row_indices = tf.cast(row_indices, tf.int32) + if num_microbatches is not None: + microbatch_size = tf.cast(nrows / num_microbatches, tf.int32) + nrows = num_microbatches + row_indices = tf.cast( + tf.math.floordiv(row_indices, microbatch_size), tf.int32 + ) # Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()` # call. The sum is reduced by the repeated embedding indices and batch # index. It is adapted from the logic in: diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 5f445d2..7bafd18 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -26,53 +26,64 @@ def make_dp_model_class(cls): __doc__ = ( """DP subclass of `{base_model}`. - This can be used as a differentially private replacement for - {base_model}. This class implements DP-SGD using the standard - Gaussian mechanism. + This can be used as a differentially private replacement for + {base_model}. This class implements DP-SGD using the standard + Gaussian mechanism. - This class also utilizes a faster gradient clipping algorithm if the - following two conditions hold: + This class also utilizes a faster gradient clipping algorithm if the + following two conditions hold: (i) the trainable layers of the model are keys in the `dict` input `layer_registry`, (ii) the loss `tf.Tensor` for a given batch of examples is either a scalar or a 2D `tf.Tensor` that has only one column `(i.e., tf.shape(loss)[1] == 1)` and whose i-th row corresponds to the loss of the i-th example. - This clipping algorithm specifically computes clipped gradients at the - per-example level using the layer registry functions in `layer_registry` - (see clip_grads.py for more information about the algorithm). In this - setting, microbatching is not used (it is equivalent to - `num_microbatches == batch_size`), and the input `num_microbatches` - is ignored. + This clipping algorithm specifically computes clipped gradients at the + per-example or per microbatch (when `num_microbatches` is not None) + level using the layer registry functions in `layer_registry` (see + clip_grads.py for more information about the algorithm). - When instantiating this class, you need to supply several - DP-related arguments followed by the standard arguments for - `{short_base_model}`. + WARNING: with faster gradient clipping, and when num_microbatches is not + None, the per microbatch loss is assumed to be computed as the mean + of the loss over the microbatch, or effectively, by reshaping the loss + from the shape [batch_size, ...] to the shape + [num_microbatches, batch_size/num_microbatches, ...] and computing the + mean of the loss over the microbatches. This would require that the loss + function behaves accordingly. This is true for multiple common + predefined keras loss functions (e.g. mean_squared_loss, + binary_crossentropy) but may not hold for custom losses (and how such + aggregation is done is not exposed by the loss function, unfortunately). + It is the caller's responsibility to make sure that the loss function + does behave this way. - Examples: + When instantiating this class, you need to supply several + DP-related arguments followed by the standard arguments for + `{short_base_model}`. 
- ```python - # Create Model instance. - model = {dp_model_class}(l2_norm_clip=1.0, noise_multiplier=0.5, use_xla=True, - ) - ``` + Examples: - You should use your {dp_model_class} instance with a standard instance - of `tf.keras.Optimizer` as the optimizer, and a standard reduced loss. - You do not need to use a differentially private optimizer. + ```python + # Create Model instance. + model = {dp_model_class}(l2_norm_clip=1.0, noise_multiplier=0.5, use_xla=True, + ) + ``` - ```python - # Use a standard (non-DP) optimizer. - optimizer = tf.keras.optimizers.SGD(learning_rate=0.01) + You should use your {dp_model_class} instance with a standard instance + of `tf.keras.Optimizer` as the optimizer, and a standard reduced loss. + You do not need to use a differentially private optimizer. - # Use a standard reduced loss. - loss = tf.keras.losses.MeanSquaredError() + ```python + # Use a standard (non-DP) optimizer. + optimizer = tf.keras.optimizers.SGD(learning_rate=0.01) - model.compile(optimizer=optimizer, loss=loss) - model.fit(train_data, train_labels, epochs=1, batch_size=32) - ``` + # Use a standard reduced loss. + loss = tf.keras.losses.MeanSquaredError() - """ + model.compile(optimizer=optimizer, loss=loss) + model.fit(train_data, train_labels, epochs=1, batch_size=32) + ``` + + """ ).format( base_model='tf.keras.' + cls.__name__, short_base_model=cls.__name__, @@ -115,6 +126,7 @@ def make_dp_model_class(cls): if isinstance(num_microbatches, bool): raise ValueError('Boolean value supplied for `num_microbatches`. ' 'Did you intend it for `use_xla`?') + self._num_microbatches = num_microbatches # If all the trainable layers are in the input layer registry, we # don't need to use microbatching and can instead use the "fast" @@ -126,16 +138,8 @@ def make_dp_model_class(cls): ) and gradient_clipping_utils.has_internal_compute_graph(self) ): - if num_microbatches is not None: - raise ValueError( - 'Cannot initialize a model where num_microbatches ' - 'is not `None` and all trainable layers are ' - 'registered in layer_registry.' - ) - self._num_microbatches = None self._enable_fast_peg_computation = True else: - self._num_microbatches = num_microbatches self._enable_fast_peg_computation = False if use_xla: @@ -198,10 +202,20 @@ def make_dp_model_class(cls): # trick, and uses these norms to clip the per-example gradients. 
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data) y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients( - self, x, y, self._l2_norm_clip, self._layer_registry + self, + x, + y, + self._l2_norm_clip, + self._layer_registry, + self._num_microbatches, ) + batch_size = self._num_microbatches or tf.shape(y)[0] grads = gradient_clipping_utils.add_aggregate_noise( - self, x, clipped_grads, self._l2_norm_clip, self._noise_multiplier + self, + clipped_grads, + batch_size, + self._l2_norm_clip, + self._noise_multiplier, ) else: logging.info('Computing gradients using microbatching.') diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py index 4bb3c4f..59e60a4 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py @@ -139,9 +139,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]]) learning_rate = 1.0 - for test_reg, test_nm in zip( - get_layer_registries(), [num_microbatches, None] - ): + for test_reg in get_layer_registries(): optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) loss = tf.keras.losses.MeanSquaredError() @@ -149,7 +147,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): model = dp_keras_model.DPSequential( l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, - num_microbatches=test_nm, + num_microbatches=num_microbatches, layer_registry=test_reg, layers=[ tf.keras.layers.InputLayer(input_shape=(2,)), @@ -173,10 +171,11 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): train_data, train_labels, w, l2_norm_clip, effective_num_microbatches ) expected_weights = np.squeeze(-learning_rate * expected_grads) - self.assertAllClose(model_weights, expected_weights) @parameterized.named_parameters( + ('noise_multiplier 3 2 None', 3.0, 2.0, None), + ('noise_multiplier 5 4 None', 5.0, 4.0, None), ('noise_multiplier 3 2 1', 3.0, 2.0, 1), ('noise_multiplier 5 4 1', 5.0, 4.0, 1), ('noise_multiplier 3 2 2', 3.0, 2.0, 2), @@ -198,9 +197,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): learning_rate = 1.0 - for test_reg, test_nm in zip( - get_layer_registries(), [num_microbatches, None] - ): + for test_reg in get_layer_registries(): optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) loss = tf.keras.losses.MeanSquaredError() @@ -208,7 +205,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): model = dp_keras_model.DPSequential( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, - num_microbatches=test_nm, + num_microbatches=num_microbatches, layer_registry=test_reg, layers=[ tf.keras.layers.InputLayer(input_shape=(1000,)), @@ -220,11 +217,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): model.compile(optimizer=optimizer, loss=loss) model.fit(train_data, train_labels, epochs=1, batch_size=4) - effective_num_microbatches = ( - train_data.shape[0] - if model._num_microbatches is None - else num_microbatches - ) + effective_num_microbatches = num_microbatches or train_data.shape[0] model_weights = model.get_weights() measured_std = np.std(model_weights[0]) @@ -248,16 +241,14 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): train_labels = np.array([[0], [1], [1], [0]]) learning_rate = 1.0 - for test_reg, test_nm in zip( - get_layer_registries(), [num_microbatches, None] - ): + for test_reg 
in get_layer_registries(): optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) model = dp_keras_model.DPSequential( l2_norm_clip=1.0e9, noise_multiplier=0.0, - num_microbatches=test_nm, + num_microbatches=num_microbatches, layer_registry=test_reg, layers=[ tf.keras.layers.InputLayer(input_shape=(2,)),
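A note on the dense-layer registry change in `layer_registry.py`: when `num_microbatches` is set, `dense_layer_computation` replaces the per-example product of squared norms with per-microbatch Gram matrices. The identity it relies on is that the squared Frobenius norm of a microbatch kernel gradient, `sum_j a_j^T g_j` (with `a_j` the layer inputs and `g_j` the gradients w.r.t. the pre-activations), equals the sum of the elementwise product of the input and pre-activation-gradient Gram matrices. A small self-contained check of that identity, with shapes chosen for illustration only (this is not code from the patch):

```python
import tensorflow as tf

batch_size, num_microbatches, d_in, d_out = 6, 3, 4, 2
a = tf.random.normal([batch_size, d_in])   # dense-layer inputs
g = tf.random.normal([batch_size, d_out])  # grads w.r.t. pre-activations

def per_microbatch(x):
  # Reshape [batch, d] -> [num_microbatches, batch / num_microbatches, d].
  return tf.reshape(x, [num_microbatches, -1, x.shape[-1]])

a_mb, g_mb = per_microbatch(a), per_microbatch(g)
gram_a = tf.matmul(a_mb, a_mb, transpose_b=True)  # [3, 2, 2]
gram_g = tf.matmul(g_mb, g_mb, transpose_b=True)  # [3, 2, 2]
fast_sqr_norms = tf.reduce_sum(gram_a * gram_g, axis=[1, 2])

# Direct computation for comparison: the per-microbatch kernel gradients.
kernel_grads = tf.einsum('mbi,mbo->mio', a_mb, g_mb)  # [3, 4, 2]
direct_sqr_norms = tf.reduce_sum(tf.square(kernel_grads), axis=[1, 2])
tf.debugging.assert_near(fast_sqr_norms, direct_sqr_norms)
```

Without microbatching, the registry keeps the cheaper `||a||^2 * ||g||^2` per-example product, which is the diagonal special case of the same identity; that is what the "Special handling for better efficiency" branch above preserves.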