From 09c68750d715cca3fe2e12e9f93db7e899bda6f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 8 Aug 2024 11:51:20 -0700 Subject: [PATCH] Sparsity Preserving DP-SGD in TF Privacy Refactor model_forward_backward_pass out of compute_gradients to allow for other optimizations such as sparsity preserving noise to integrate with it. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 660924829 --- .../privacy/fast_gradient_clipping/BUILD | 2 + .../fast_gradient_clipping/clip_grads.py | 267 +++++++----------- .../fast_gradient_clipping/clip_grads_test.py | 34 ++- .../gradient_clipping_utils.py | 145 ++++++++++ .../fast_gradient_clipping/type_aliases.py | 2 + .../privacy/keras_models/dp_keras_model.py | 24 +- 6 files changed, 310 insertions(+), 164 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index adb5a76..f5b920f 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -43,6 +43,7 @@ py_library( srcs = ["gradient_clipping_utils.py"], srcs_version = "PY3", deps = [ + ":common_manip_utils", ":layer_registry", ":type_aliases", ], @@ -94,6 +95,7 @@ py_test( deps = [ ":clip_grads", ":common_test_utils", + ":gradient_clipping_utils", ":layer_registry", ":type_aliases", ], diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 100e66a..e31f178 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -22,7 +22,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the """ import collections -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Optional import tensorflow as tf @@ -32,73 +32,81 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases -def _infer_per_example_loss_fn(model: tf.keras.Model): - """Infer the per-example loss from model config.""" +def _compute_gradient_norms_internal( + registry_fn_outputs_list: Sequence[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]], + trainable_vars: Optional[Sequence[tf.Variable]] = None, +) -> tf.Tensor: + """Computes the per-example loss gradient norms for given data. - def _convert(loss_fn): - loss_config = loss_fn.get_config() - loss_config['reduction'] = tf.keras.losses.Reduction.NONE - return loss_fn.from_config(loss_config) + Args: + registry_fn_outputs_list: A sequence of RegistryGeneratorFunctionOutput + containing information required to compute the gradient norms and + contribution counts. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + layer_grad_vars: A mapping of layer id to a list of gradients for each + trainable variable in the layer. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + trainable_vars: The list of variables included in computing the gradient + norm. When a layer has multiple variables, we include all the variables if + any of the variables is in the list. If `trainable_vars` is None, all the + variables are included. - model_loss = model.loss - if isinstance(model_loss, tf.keras.losses.Loss): - return _convert(model_loss) - elif isinstance(model_loss, dict): - # Note that we cannot call the public method `.get_compile_config()` because - # it calls a numpy function, which is not supported inside a `tf.function` - # wrapped function. - compile_config = model._compile_config.config # pylint: disable=protected-access - if compile_config is None: - raise ValueError('Model must be compiled for loss function conversion') - # Does a weighted mean of the configured losses. Note that we cannot build - # from the config of the compiled loss because (i) it builds a - # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s - # during its construction, (ii) non-unique `tf.Variables` cannot be used - # inside a `tf.function`, which is usually where this function is used. - if 'loss_weights' not in compile_config: - raise ValueError( - 'Models with multiple loss must have corresponding loss weights for' - ' loss function conversion' + Returns: + A scalar vector, whose i-th entry is the norm of the gradient of the i-th + weighted example loss (when num_microbatches is None) or the norm of the + gradient of the i-th microbatch loss (define as a mean over the microbatch). + Note that when the loss is weighted (`weight_batch` is not None), weights + are applied prior to clipping. + + Raises: + ValueError: If `layer_grad_vars` is empty. + ValueError: If the number of gradients for a layer is not equal to the + number of squared norm functions for that layer. + """ + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. + trainable_vars = set([v.ref() for v in trainable_vars]) + + layer_sqr_norm_fns = collections.defaultdict(list) + # The case of shared weights: + # If a layer is called k times, it will appear k times in filtered_outputs, + # with the same id, but potentially with different v and f. The code below + # groups filtered_outputs by layer_id, so we can correctly compute gradient + # norms. The gradient norm of a layer that occurs k times is computed as + # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th + # occurrence. This is an over-estimate of the actual norm. For more details, + # see the explanation in go/dp-sgd-shared-weights. + for registry_fn_output in registry_fn_outputs_list: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + layer_sqr_norm_fns[registry_fn_output.layer_id].append( + registry_fn_output.layer_sqr_norm_fn ) - weights = compile_config['loss_weights'] - per_example_losses = {k: _convert(v) for k, v in model_loss.items()} - num_losses = len(weights) - def _per_example_loss_fn(y_true, y_pred, sample_weight=None): - loss_values = [] - if model_loss.keys() - y_pred.keys(): - raise ValueError( - 'y_pred must contain the same keys and the model losses, but ' - 'got %s and %s' % (y_pred.keys(), model_loss.keys()) - ) - if model_loss.keys() - y_true.keys(): - raise ValueError( - 'y_true must contain the same keys and the model losses, but ' - 'got %s and %s' % (y_true.keys(), model_loss.keys()) - ) - if sample_weight is not None: - if model_loss.keys() - sample_weight.keys(): - raise ValueError( - 'sample_weight must contain the same keys and the model losses,' - ' but got %s and %s' % (y_true.keys(), model_loss.keys()) - ) - for k in y_true.keys(): - sgl_sample_weight = None if sample_weight is None else sample_weight[k] - sgl_value = ( - weights[k] - * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight) - / num_losses - ) - loss_values.append(tf.reshape(sgl_value, shape=[-1])) - return tf.math.add_n(loss_values) - - return _per_example_loss_fn - else: - raise ValueError( - 'Unsupported type for loss function conversion: {}'.format( - type(model_loss) - ) - ) + if not layer_grad_vars: + raise ValueError('The gradient list cannot be empty.') + sqr_norm_list = [] + for layer_id in layer_sqr_norm_fns.keys(): + fns = layer_sqr_norm_fns[layer_id] + grads = layer_grad_vars[layer_id] + # Number of duplicates for this layer in `filtered_outputs`. + num_passes = len(fns) + if len(fns) != len(grads): + raise ValueError( + 'There must be as many gradients as squared norm functions.' + ) + # See go/dp-sgd-shared-weights for more details. + for fn, grad in zip(fns, grads): + sqr_norm_list.append(num_passes * fn(grad)) + sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) + gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) + return gradient_norms def compute_gradient_norms( @@ -110,7 +118,7 @@ def compute_gradient_norms( per_example_loss_fn: Optional[type_aliases.LossFn] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, trainable_vars: Optional[Sequence[tf.Variable]] = None, -): +) -> tf.Tensor: """Computes the per-example loss gradient norms for given data. Applies a variant of the approach given in @@ -154,90 +162,27 @@ def compute_gradient_norms( """ tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( - tape, layer_registry, num_microbatches + tape=tape, + layer_registry=layer_registry, + num_microbatches=num_microbatches, ) - # First loop computes the model outputs, summed loss, and generator outputs. - with tape: - model_outputs, generator_outputs_list = ( - gradient_clipping_utils.model_forward_pass( - input_model, x_batch, generator_fn=registry_generator_fn - ) - ) - - # Ignore the original loss function's reduction to get per-example loss. - if per_example_loss_fn is None: - per_example_loss_fn = _infer_per_example_loss_fn(input_model) - - losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) - if losses.shape is None: - raise NotImplementedError( - "The unreduced (or per-example) loss's shape cannot be `None`" + layer_grad_vars, generator_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=input_model, + x_batch=x_batch, + y_batch=y_batch, + registry_generator_fn=registry_generator_fn, + weight_batch=weight_batch, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, ) - if len(losses.shape) != 1: - raise NotImplementedError( - 'The unreduced (or per-example) loss needs to have a shape of length ' - 'one, but received an unreduced loss of shape length %s' - % len(losses.shape) - ) - if num_microbatches is not None: - losses = tf.reduce_mean( - common_manip_utils.maybe_add_microbatch_axis( - losses, num_microbatches - ), - axis=1, - ) - summed_loss = tf.reduce_sum(losses) - # Unwrap the generator outputs so that the next loop avoids duplicating - # backprop ops. - filtered_outputs = [t for t in generator_outputs_list if t is not None] - if trainable_vars is not None: - # Create a set using `ref()` for fast set membership check. tf.Variable - # itself is not hashable. - trainable_vars = set([v.ref() for v in trainable_vars]) - layer_vars = collections.defaultdict(list) - layer_sqr_norm_fns = collections.defaultdict(list) - # The case of shared weights: - # If a layer is called k times, it will appear k times in filtered_outputs, - # with the same id, but potentially with different v and f. The code below - # groups filtered_outputs by layer_id, so we can correctly compute gradient - # norms. The gradient norm of a layer that occurs k times is computed as - # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th - # occurrence. This is an over-estimate of the actual norm. For more details, - # see the explanation in go/dp-sgd-shared-weights. - for registry_fn_output in filtered_outputs: - if trainable_vars is None or any( - w.ref() in trainable_vars - for w in registry_fn_output.layer_trainable_weights - ): - layer_vars[registry_fn_output.layer_id].append( - registry_fn_output.layer_vars - ) - layer_sqr_norm_fns[registry_fn_output.layer_id].append( - registry_fn_output.layer_sqr_norm_fn - ) - # Second loop evaluates the squared L2 norm functions and appends the results. - layer_grad_vars = tape.gradient( - summed_loss, - layer_vars, - unconnected_gradients=tf.UnconnectedGradients.ZERO, ) - if not layer_grad_vars: - raise ValueError('The gradient list cannot be empty.') - sqr_norm_list = [] - for layer_id in layer_sqr_norm_fns.keys(): - fns = layer_sqr_norm_fns[layer_id] - grads = layer_grad_vars[layer_id] - # Number of duplicates for this layer in `filtered_outputs`. - num_passes = len(fns) - if len(fns) != len(grads): - raise ValueError( - 'There must be as many gradients as squared norm functions.' - ) - # See go/dp-sgd-shared-weights for more details. - for fn, grad in zip(fns, grads): - sqr_norm_list.append(num_passes * fn(grad)) - sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) - return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) + return _compute_gradient_norms_internal( + registry_fn_outputs_list=generator_outputs_list, + layer_grad_vars=layer_grad_vars, + trainable_vars=trainable_vars, + ) def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): @@ -267,14 +212,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): def compute_clipped_gradients_and_outputs( input_model: tf.keras.Model, + registry_fn_outputs_list: Sequence[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]], l2_norm_clip: float, - layer_registry: lr.LayerRegistry, x_batch: type_aliases.InputTensors, y_batch: type_aliases.OutputTensors, weight_batch: Optional[tf.Tensor] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, clipping_loss: Optional[type_aliases.LossFn] = None, -) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]: +) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]: """Computes the per-example clipped loss gradient and other useful outputs. Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main @@ -287,15 +235,16 @@ def compute_clipped_gradients_and_outputs( Args: input_model: The `tf.keras.Model` from which to obtain the layers from. + registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput + containing information required to compute the gradient norms and + contribution counts. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + layer_grad_vars: A mapping of layer id to a list of gradients for each + trainablev ariable in the layer. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. l2_norm_clip: A `float` indicating the norm to which per-example gradients will be clipped. That is, all gradients of the per-example loss functions will have norm at most `l2_norm_clip`. - layer_registry: A `dict` of layers that support "fast" gradient norm - computations. The key is the class of the layer and the value is a - function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where - `output` is the pre-activator tensor, `sqr_grad_norms` is related to the - squared norms of a layer's pre-activation tensor, and `vars` are relevant - trainable weights (see `layer_registry_factories.py` for examples). x_batch: An `InputTensor` representing a batch of inputs to the model. The first axes of each tensor must be the batch dimension. y_batch: An `OutputTensor` representing a batch of output labels. The first @@ -330,13 +279,9 @@ def compute_clipped_gradients_and_outputs( ) if clipping_loss is None: clipping_loss = input_model.compiled_loss - gradient_norms = compute_gradient_norms( - input_model, - layer_registry, - x_batch, - y_batch, - weight_batch, - num_microbatches=num_microbatches, + gradient_norms = _compute_gradient_norms_internal( + registry_fn_outputs_list=registry_fn_outputs_list, + layer_grad_vars=layer_grad_vars, trainable_vars=input_model.trainable_variables, ) clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index a028aff..7b91461 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -19,6 +19,7 @@ from absl.testing import parameterized import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -122,6 +123,29 @@ class CustomLayerTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) +def _run_model_forward_backward_pass( + model: tf.keras.Model, + x_batch: type_aliases.InputTensors, + y_batch: type_aliases.OutputTensors, +): + tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) + registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( + tape=tape, + layer_registry=layer_registry.make_default_layer_registry(), + num_microbatches=None, + ) + layer_grad_vars, registry_fn_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=model, + x_batch=x_batch, + y_batch=y_batch, + registry_generator_fn=registry_generator_fn, + ) + ) + return layer_grad_vars, registry_fn_outputs_list + + class ComputeClippedGradsAndOutputsTest( tf.test.TestCase, parameterized.TestCase ): @@ -153,13 +177,17 @@ class ComputeClippedGradsAndOutputsTest( y_batch = tf.reshape( 1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1] ) + layer_grad_vars, registry_fn_outputs_list = ( + _run_model_forward_backward_pass(self._model, x_batch, y_batch) + ) # Stop early for efficiency. if reduction == 'none': with self.assertRaises(NotImplementedError): clip_grads.compute_clipped_gradients_and_outputs( self._model, + registry_fn_outputs_list, + layer_grad_vars, l2_norm_clip, - layer_registry.make_default_layer_registry(), x_batch, y_batch, ) @@ -169,10 +197,12 @@ class ComputeClippedGradsAndOutputsTest( y_pred = self._model(x_batch) loss_value = loss_fn(y_pred, y_batch) true_grads = tape.gradient(loss_value, self._model.trainable_variables) + clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs( self._model, + registry_fn_outputs_list, + layer_grad_vars, l2_norm_clip, - layer_registry.make_default_layer_registry(), x_batch, y_batch, ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 7a060e9..bac323c 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,11 +13,13 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" +import collections from collections.abc import Callable, Sequence, Set import dataclasses from typing import Any, Optional, Tuple import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -98,6 +100,149 @@ def get_registry_generator_fn( return registry_generator_fn +def _infer_per_example_loss_fn(model: tf.keras.Model): + """Infer the per-example loss from model config.""" + + def _convert(loss_fn): + loss_config = loss_fn.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + return loss_fn.from_config(loss_config) + + model_loss = model.loss + if isinstance(model_loss, tf.keras.losses.Loss): + return _convert(model_loss) + elif isinstance(model_loss, dict): + # Note that we cannot call the public method `.get_compile_config()` because + # it calls a numpy function, which is not supported inside a `tf.function` + # wrapped function. + compile_config = model._compile_config.config # pylint: disable=protected-access + if compile_config is None: + raise ValueError('Model must be compiled for loss function conversion') + # Does a weighted mean of the configured losses. Note that we cannot build + # from the config of the compiled loss because (i) it builds a + # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s + # during its construction, (ii) non-unique `tf.Variables` cannot be used + # inside a `tf.function`, which is usually where this function is used. + if 'loss_weights' not in compile_config: + raise ValueError( + 'Models with multiple loss must have corresponding loss weights for' + ' loss function conversion' + ) + weights = compile_config['loss_weights'] + per_example_losses = {k: _convert(v) for k, v in model_loss.items()} + num_losses = len(weights) + + def _per_example_loss_fn(y_true, y_pred, sample_weight=None): + loss_values = [] + if model_loss.keys() - y_pred.keys(): + raise ValueError( + 'y_pred must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_pred.keys(), model_loss.keys()) + ) + if model_loss.keys() - y_true.keys(): + raise ValueError( + 'y_true must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + if sample_weight is not None: + if model_loss.keys() - sample_weight.keys(): + raise ValueError( + 'sample_weight must contain the same keys and the model losses,' + ' but got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + for k in y_true.keys(): + sgl_sample_weight = None if sample_weight is None else sample_weight[k] + sgl_value = ( + weights[k] + * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight) + / num_losses + ) + loss_values.append(tf.reshape(sgl_value, shape=[-1])) + return tf.math.add_n(loss_values) + + return _per_example_loss_fn + else: + raise ValueError( + 'Unsupported type for loss function conversion: {}'.format( + type(model_loss) + ) + ) + + +def model_forward_backward_pass( + tape: tf.GradientTape, + input_model: tf.keras.Model, + x_batch: type_aliases.InputTensors, + y_batch: type_aliases.OutputTensors, + registry_generator_fn: Optional[ + Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]] + ], + weight_batch: Optional[tf.Tensor] = None, + per_example_loss_fn: Optional[type_aliases.LossFn] = None, + num_microbatches: Optional[type_aliases.BatchSize] = None, + trainable_vars: Optional[Sequence[tf.Variable]] = None, +) -> tuple[ + dict[str, list[type_aliases.Tensor]], list[RegistryGeneratorFunctionOutput] +]: + """Does a forward and backward pass of a model and returns useful intermediates.""" + # First loop computes the model outputs, summed loss, and generator outputs. + with tape: + model_outputs, generator_outputs_list = model_forward_pass( + input_model, x_batch, generator_fn=registry_generator_fn + ) + + # Ignore the original loss function's reduction to get per-example loss. + if per_example_loss_fn is None: + per_example_loss_fn = _infer_per_example_loss_fn(input_model) + + losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) + if losses.shape is None: + raise NotImplementedError( + "The unreduced (or per-example) loss's shape cannot be `None`" + ) + if len(losses.shape) != 1: + raise NotImplementedError( + 'The unreduced (or per-example) loss needs to have a shape of length ' + 'one, but received an unreduced loss of shape length %s' + % len(losses.shape) + ) + if num_microbatches is not None: + losses = tf.reduce_mean( + common_manip_utils.maybe_add_microbatch_axis( + losses, num_microbatches + ), + axis=1, + ) + summed_loss = tf.reduce_sum(losses) + # Unwrap the generator outputs so that the next loop avoids duplicating + # backprop ops. + filtered_outputs = [t for t in generator_outputs_list if t is not None] + + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. + trainable_vars = set([v.ref() for v in trainable_vars]) + layer_vars = collections.defaultdict(list) + for registry_fn_output in filtered_outputs: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + layer_vars[registry_fn_output.layer_id].append( + registry_fn_output.layer_vars + ) + + layer_grad_vars = tape.gradient( + summed_loss, + layer_vars, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + if not layer_grad_vars: + raise ValueError('The gradient list cannot be empty.') + + return layer_grad_vars, filtered_outputs + + def model_forward_pass( input_model: tf.keras.Model, inputs: type_aliases.PackedTensors, diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py index 1e602a0..b064c69 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py @@ -19,6 +19,8 @@ import tensorflow as tf # Tensorflow aliases. +Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor] + PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]] InputTensors = PackedTensors diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 2a28d69..b7104f4 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -274,14 +274,36 @@ def make_dp_model_class(cls): # trick, and uses these norms to clip the per-example gradients. # NOTE: Reshaping of the input according to the effective number of # microbatches is done here. + tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) + + registry_generator_fn = ( + gradient_clipping_utils.get_registry_generator_fn( + tape=tape, + layer_registry=self._layer_registry, + num_microbatches=num_microbatches, + ) + ) + layer_grad_vars, registry_fn_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=self, + x_batch=x, + y_batch=y, + registry_generator_fn=registry_generator_fn, + weight_batch=weights, + num_microbatches=num_microbatches, + trainable_vars=self.trainable_variables, + ) + ) clipped_grads, y_pred, clipping_loss = ( clip_grads.compute_clipped_gradients_and_outputs( input_model=self, + registry_fn_outputs_list=registry_fn_outputs_list, + layer_grad_vars=layer_grad_vars, x_batch=x, y_batch=y, weight_batch=weights, l2_norm_clip=self._l2_norm_clip, - layer_registry=self._layer_registry, num_microbatches=self._num_microbatches, clipping_loss=self._clipping_loss, )