Sparsity Preserving DP-SGD in TF Privacy

Refactor model_forward_backward_pass out of compute_gradients so that other optimizations, such as sparsity preserving noise, can integrate with it.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660924829
Author: A. Unique TensorFlower
Date: 2024-08-08 11:51:20 -07:00
Commit: 09c68750d7 (parent: 8294cec132)
6 changed files with 310 additions and 164 deletions
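For orientation, a minimal, self-contained sketch (not part of the commit itself) of the call flow after this refactor, mirroring the new test helper and the dp_keras_model.py call site in the diffs below. The toy model, batch data, and l2_norm_clip value are illustrative placeholders; the registry and utility functions are the ones touched by this change.

# Sketch only: wiring the refactored pieces together.
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Toy compiled model and batch (placeholders for illustration).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())
x_batch = tf.random.normal([8, 4])
y_batch = tf.random.normal([8, 1])
l2_norm_clip = 1.0

tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
    tape=tape,
    layer_registry=layer_registry.make_default_layer_registry(),
    num_microbatches=None,
)
# The forward/backward pass is now a standalone step; other optimizations
# (e.g. sparsity preserving noise) can consume its outputs directly.
layer_grad_vars, registry_fn_outputs_list = (
    gradient_clipping_utils.model_forward_backward_pass(
        tape=tape,
        input_model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        registry_generator_fn=registry_generator_fn,
    )
)
# Norm computation and clipping reuse those intermediates instead of
# re-running the forward pass themselves.
clipped_grads, y_pred, clipping_loss = (
    clip_grads.compute_clipped_gradients_and_outputs(
        model,
        registry_fn_outputs_list,
        layer_grad_vars,
        l2_norm_clip,
        x_batch,
        y_batch,
    )
)

In dp_keras_model.py's train_step (see the last file below) the same three steps run with the model's own layer registry, microbatch setting, and clipping loss.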

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

@@ -43,6 +43,7 @@ py_library(
     srcs = ["gradient_clipping_utils.py"],
     srcs_version = "PY3",
     deps = [
+        ":common_manip_utils",
         ":layer_registry",
         ":type_aliases",
     ],
@@ -94,6 +95,7 @@ py_test(
     deps = [
         ":clip_grads",
         ":common_test_utils",
+        ":gradient_clipping_utils",
         ":layer_registry",
         ":type_aliases",
     ],

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -22,7 +22,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 """

 import collections
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Optional

 import tensorflow as tf
@@ -32,73 +32,81 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


-def _infer_per_example_loss_fn(model: tf.keras.Model):
-  """Infer the per-example loss from model config."""
-
-  def _convert(loss_fn):
-    loss_config = loss_fn.get_config()
-    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
-    return loss_fn.from_config(loss_config)
-
-  model_loss = model.loss
-  if isinstance(model_loss, tf.keras.losses.Loss):
-    return _convert(model_loss)
-  elif isinstance(model_loss, dict):
-    # Note that we cannot call the public method `.get_compile_config()` because
-    # it calls a numpy function, which is not supported inside a `tf.function`
-    # wrapped function.
-    compile_config = model._compile_config.config  # pylint: disable=protected-access
-    if compile_config is None:
-      raise ValueError('Model must be compiled for loss function conversion')
-    # Does a weighted mean of the configured losses. Note that we cannot build
-    # from the config of the compiled loss because (i) it builds a
-    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
-    # during its construction, (ii) non-unique `tf.Variables` cannot be used
-    # inside a `tf.function`, which is usually where this function is used.
-    if 'loss_weights' not in compile_config:
-      raise ValueError(
-          'Models with multiple loss must have corresponding loss weights for'
-          ' loss function conversion'
-      )
-    weights = compile_config['loss_weights']
-    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
-    num_losses = len(weights)
-
-    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
-      loss_values = []
-      if model_loss.keys() - y_pred.keys():
-        raise ValueError(
-            'y_pred must contain the same keys and the model losses, but '
-            'got %s and %s' % (y_pred.keys(), model_loss.keys())
-        )
-      if model_loss.keys() - y_true.keys():
-        raise ValueError(
-            'y_true must contain the same keys and the model losses, but '
-            'got %s and %s' % (y_true.keys(), model_loss.keys())
-        )
-      if sample_weight is not None:
-        if model_loss.keys() - sample_weight.keys():
-          raise ValueError(
-              'sample_weight must contain the same keys and the model losses,'
-              ' but got %s and %s' % (y_true.keys(), model_loss.keys())
-          )
-      for k in y_true.keys():
-        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
-        sgl_value = (
-            weights[k]
-            * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
-            / num_losses
-        )
-        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
-      return tf.math.add_n(loss_values)
-
-    return _per_example_loss_fn
-  else:
-    raise ValueError(
-        'Unsupported type for loss function conversion: {}'.format(
-            type(model_loss)
-        )
-    )
+def _compute_gradient_norms_internal(
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
+    trainable_vars: Optional[Sequence[tf.Variable]] = None,
+) -> tf.Tensor:
+  """Computes the per-example loss gradient norms for given data.
+
+  Args:
+    registry_fn_outputs_list: A sequence of RegistryGeneratorFunctionOutput
+      containing information required to compute the gradient norms and
+      contribution counts. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    layer_grad_vars: A mapping of layer id to a list of gradients for each
+      trainable variable in the layer. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    trainable_vars: The list of variables included in computing the gradient
+      norm. When a layer has multiple variables, we include all the variables if
+      any of the variables is in the list. If `trainable_vars` is None, all the
+      variables are included.
+
+  Returns:
+    A scalar vector, whose i-th entry is the norm of the gradient of the i-th
+    weighted example loss (when num_microbatches is None) or the norm of the
+    gradient of the i-th microbatch loss (defined as a mean over the
+    microbatch). Note that when the loss is weighted (`weight_batch` is not
+    None), weights are applied prior to clipping.
+
+  Raises:
+    ValueError: If `layer_grad_vars` is empty.
+    ValueError: If the number of gradients for a layer is not equal to the
+      number of squared norm functions for that layer.
+  """
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
+
+  layer_sqr_norm_fns = collections.defaultdict(list)
+  # The case of shared weights:
+  # If a layer is called k times, it will appear k times in filtered_outputs,
+  # with the same id, but potentially with different v and f. The code below
+  # groups filtered_outputs by layer_id, so we can correctly compute gradient
+  # norms. The gradient norm of a layer that occurs k times is computed as
+  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
+  # occurrence. This is an over-estimate of the actual norm. For more details,
+  # see the explanation in go/dp-sgd-shared-weights.
+  for registry_fn_output in registry_fn_outputs_list:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
+    ):
+      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_sqr_norm_fn
+      )
+
+  if not layer_grad_vars:
+    raise ValueError('The gradient list cannot be empty.')
+  sqr_norm_list = []
+  for layer_id in layer_sqr_norm_fns.keys():
+    fns = layer_sqr_norm_fns[layer_id]
+    grads = layer_grad_vars[layer_id]
+    # Number of duplicates for this layer in `filtered_outputs`.
+    num_passes = len(fns)
+    if len(fns) != len(grads):
+      raise ValueError(
+          'There must be as many gradients as squared norm functions.'
+      )
+    # See go/dp-sgd-shared-weights for more details.
+    for fn, grad in zip(fns, grads):
+      sqr_norm_list.append(num_passes * fn(grad))
+  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
+  gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+  return gradient_norms


 def compute_gradient_norms(
@@ -110,7 +118,7 @@ def compute_gradient_norms(
     per_example_loss_fn: Optional[type_aliases.LossFn] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     trainable_vars: Optional[Sequence[tf.Variable]] = None,
-):
+) -> tf.Tensor:
   """Computes the per-example loss gradient norms for given data.

   Applies a variant of the approach given in
@@ -154,90 +162,27 @@ def compute_gradient_norms(
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
   registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
-      tape, layer_registry, num_microbatches
+      tape=tape,
+      layer_registry=layer_registry,
+      num_microbatches=num_microbatches,
   )
-  # First loop computes the model outputs, summed loss, and generator outputs.
-  with tape:
-    model_outputs, generator_outputs_list = (
-        gradient_clipping_utils.model_forward_pass(
-            input_model, x_batch, generator_fn=registry_generator_fn
-        )
-    )
-
-    # Ignore the original loss function's reduction to get per-example loss.
-    if per_example_loss_fn is None:
-      per_example_loss_fn = _infer_per_example_loss_fn(input_model)
-    losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
-    if losses.shape is None:
-      raise NotImplementedError(
-          "The unreduced (or per-example) loss's shape cannot be `None`"
-      )
-    if len(losses.shape) != 1:
-      raise NotImplementedError(
-          'The unreduced (or per-example) loss needs to have a shape of length '
-          'one, but received an unreduced loss of shape length %s'
-          % len(losses.shape)
-      )
-    if num_microbatches is not None:
-      losses = tf.reduce_mean(
-          common_manip_utils.maybe_add_microbatch_axis(
-              losses, num_microbatches
-          ),
-          axis=1,
-      )
-    summed_loss = tf.reduce_sum(losses)
-  # Unwrap the generator outputs so that the next loop avoids duplicating
-  # backprop ops.
-  filtered_outputs = [t for t in generator_outputs_list if t is not None]
-  if trainable_vars is not None:
-    # Create a set using `ref()` for fast set membership check. tf.Variable
-    # itself is not hashable.
-    trainable_vars = set([v.ref() for v in trainable_vars])
-  layer_vars = collections.defaultdict(list)
-  layer_sqr_norm_fns = collections.defaultdict(list)
-  # The case of shared weights:
-  # If a layer is called k times, it will appear k times in filtered_outputs,
-  # with the same id, but potentially with different v and f. The code below
-  # groups filtered_outputs by layer_id, so we can correctly compute gradient
-  # norms. The gradient norm of a layer that occurs k times is computed as
-  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
-  # occurrence. This is an over-estimate of the actual norm. For more details,
-  # see the explanation in go/dp-sgd-shared-weights.
-  for registry_fn_output in filtered_outputs:
-    if trainable_vars is None or any(
-        w.ref() in trainable_vars
-        for w in registry_fn_output.layer_trainable_weights
-    ):
-      layer_vars[registry_fn_output.layer_id].append(
-          registry_fn_output.layer_vars
-      )
-      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
-          registry_fn_output.layer_sqr_norm_fn
-      )
-  # Second loop evaluates the squared L2 norm functions and appends the results.
-  layer_grad_vars = tape.gradient(
-      summed_loss,
-      layer_vars,
-      unconnected_gradients=tf.UnconnectedGradients.ZERO,
-  )
-  if not layer_grad_vars:
-    raise ValueError('The gradient list cannot be empty.')
-  sqr_norm_list = []
-  for layer_id in layer_sqr_norm_fns.keys():
-    fns = layer_sqr_norm_fns[layer_id]
-    grads = layer_grad_vars[layer_id]
-    # Number of duplicates for this layer in `filtered_outputs`.
-    num_passes = len(fns)
-    if len(fns) != len(grads):
-      raise ValueError(
-          'There must be as many gradients as squared norm functions.'
-      )
-    # See go/dp-sgd-shared-weights for more details.
-    for fn, grad in zip(fns, grads):
-      sqr_norm_list.append(num_passes * fn(grad))
-  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
-  return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+  layer_grad_vars, generator_outputs_list = (
+      gradient_clipping_utils.model_forward_backward_pass(
+          tape=tape,
+          input_model=input_model,
+          x_batch=x_batch,
+          y_batch=y_batch,
+          registry_generator_fn=registry_generator_fn,
+          weight_batch=weight_batch,
+          per_example_loss_fn=per_example_loss_fn,
+          num_microbatches=num_microbatches,
+      )
+  )
+  return _compute_gradient_norms_internal(
+      registry_fn_outputs_list=generator_outputs_list,
+      layer_grad_vars=layer_grad_vars,
+      trainable_vars=trainable_vars,
+  )


 def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
@@ -267,14 +212,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):

 def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
     l2_norm_clip: float,
-    layer_registry: lr.LayerRegistry,
     x_batch: type_aliases.InputTensors,
     y_batch: type_aliases.OutputTensors,
     weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     clipping_loss: Optional[type_aliases.LossFn] = None,
-) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
+) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]:
   """Computes the per-example clipped loss gradient and other useful outputs.

   Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
@@ -287,15 +235,16 @@ def compute_clipped_gradients_and_outputs(

   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
+    registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
+      containing information required to compute the gradient norms and
+      contribution counts. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    layer_grad_vars: A mapping of layer id to a list of gradients for each
+      trainable variable in the layer. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
     l2_norm_clip: A `float` indicating the norm to which per-example gradients
       will be clipped. That is, all gradients of the per-example loss functions
       will have norm at most `l2_norm_clip`.
-    layer_registry: A `dict` of layers that support "fast" gradient norm
-      computations. The key is the class of the layer and the value is a
-      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
-      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
-      squared norms of a layer's pre-activation tensor, and `vars` are relevant
-      trainable weights (see `layer_registry_factories.py` for examples).
     x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axes of each tensor must be the batch dimension.
     y_batch: An `OutputTensor` representing a batch of output labels. The first
@@ -330,13 +279,9 @@ def compute_clipped_gradients_and_outputs(
   )
   if clipping_loss is None:
     clipping_loss = input_model.compiled_loss
-  gradient_norms = compute_gradient_norms(
-      input_model,
-      layer_registry,
-      x_batch,
-      y_batch,
-      weight_batch,
-      num_microbatches=num_microbatches,
+  gradient_norms = _compute_gradient_norms_internal(
+      registry_fn_outputs_list=registry_fn_outputs_list,
+      layer_grad_vars=layer_grad_vars,
       trainable_vars=input_model.trainable_variables,
   )
   clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py

@@ -19,6 +19,7 @@ from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -122,6 +123,29 @@ class CustomLayerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)


+def _run_model_forward_backward_pass(
+    model: tf.keras.Model,
+    x_batch: type_aliases.InputTensors,
+    y_batch: type_aliases.OutputTensors,
+):
+  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
+  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
+      tape=tape,
+      layer_registry=layer_registry.make_default_layer_registry(),
+      num_microbatches=None,
+  )
+  layer_grad_vars, registry_fn_outputs_list = (
+      gradient_clipping_utils.model_forward_backward_pass(
+          tape=tape,
+          input_model=model,
+          x_batch=x_batch,
+          y_batch=y_batch,
+          registry_generator_fn=registry_generator_fn,
+      )
+  )
+  return layer_grad_vars, registry_fn_outputs_list
+
+
 class ComputeClippedGradsAndOutputsTest(
     tf.test.TestCase, parameterized.TestCase
 ):
@@ -153,13 +177,17 @@ class ComputeClippedGradsAndOutputsTest(
     y_batch = tf.reshape(
         1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
     )
+    layer_grad_vars, registry_fn_outputs_list = (
+        _run_model_forward_backward_pass(self._model, x_batch, y_batch)
+    )
     # Stop early for efficiency.
     if reduction == 'none':
       with self.assertRaises(NotImplementedError):
         clip_grads.compute_clipped_gradients_and_outputs(
             self._model,
+            registry_fn_outputs_list,
+            layer_grad_vars,
             l2_norm_clip,
-            layer_registry.make_default_layer_registry(),
             x_batch,
             y_batch,
         )
@@ -169,10 +197,12 @@ class ComputeClippedGradsAndOutputsTest(
       y_pred = self._model(x_batch)
       loss_value = loss_fn(y_pred, y_batch)
     true_grads = tape.gradient(loss_value, self._model.trainable_variables)
     clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
         self._model,
+        registry_fn_outputs_list,
+        layer_grad_vars,
         l2_norm_clip,
-        layer_registry.make_default_layer_registry(),
         x_batch,
         y_batch,
     )

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

@@ -13,11 +13,13 @@
 # limitations under the License.

 """Utility functions that help in the computation of per-example gradient norms."""

+import collections
 from collections.abc import Callable, Sequence, Set
 import dataclasses
 from typing import Any, Optional, Tuple

 import tensorflow as tf

+from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -98,6 +100,149 @@ def get_registry_generator_fn(
   return registry_generator_fn


+def _infer_per_example_loss_fn(model: tf.keras.Model):
+  """Infer the per-example loss from model config."""
+
+  def _convert(loss_fn):
+    loss_config = loss_fn.get_config()
+    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
+    return loss_fn.from_config(loss_config)
+
+  model_loss = model.loss
+  if isinstance(model_loss, tf.keras.losses.Loss):
+    return _convert(model_loss)
+  elif isinstance(model_loss, dict):
+    # Note that we cannot call the public method `.get_compile_config()` because
+    # it calls a numpy function, which is not supported inside a `tf.function`
+    # wrapped function.
+    compile_config = model._compile_config.config  # pylint: disable=protected-access
+    if compile_config is None:
+      raise ValueError('Model must be compiled for loss function conversion')
+    # Does a weighted mean of the configured losses. Note that we cannot build
+    # from the config of the compiled loss because (i) it builds a
+    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
+    # during its construction, (ii) non-unique `tf.Variables` cannot be used
+    # inside a `tf.function`, which is usually where this function is used.
+    if 'loss_weights' not in compile_config:
+      raise ValueError(
+          'Models with multiple loss must have corresponding loss weights for'
+          ' loss function conversion'
+      )
+    weights = compile_config['loss_weights']
+    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
+    num_losses = len(weights)
+
+    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
+      loss_values = []
+      if model_loss.keys() - y_pred.keys():
+        raise ValueError(
+            'y_pred must contain the same keys and the model losses, but '
+            'got %s and %s' % (y_pred.keys(), model_loss.keys())
+        )
+      if model_loss.keys() - y_true.keys():
+        raise ValueError(
+            'y_true must contain the same keys and the model losses, but '
+            'got %s and %s' % (y_true.keys(), model_loss.keys())
+        )
+      if sample_weight is not None:
+        if model_loss.keys() - sample_weight.keys():
+          raise ValueError(
+              'sample_weight must contain the same keys and the model losses,'
+              ' but got %s and %s' % (y_true.keys(), model_loss.keys())
+          )
+      for k in y_true.keys():
+        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
+        sgl_value = (
+            weights[k]
+            * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
+            / num_losses
+        )
+        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
+      return tf.math.add_n(loss_values)
+
+    return _per_example_loss_fn
+  else:
+    raise ValueError(
+        'Unsupported type for loss function conversion: {}'.format(
+            type(model_loss)
+        )
+    )
+
+
+def model_forward_backward_pass(
+    tape: tf.GradientTape,
+    input_model: tf.keras.Model,
+    x_batch: type_aliases.InputTensors,
+    y_batch: type_aliases.OutputTensors,
+    registry_generator_fn: Optional[
+        Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]
+    ],
+    weight_batch: Optional[tf.Tensor] = None,
+    per_example_loss_fn: Optional[type_aliases.LossFn] = None,
+    num_microbatches: Optional[type_aliases.BatchSize] = None,
+    trainable_vars: Optional[Sequence[tf.Variable]] = None,
+) -> tuple[
+    dict[str, list[type_aliases.Tensor]], list[RegistryGeneratorFunctionOutput]
+]:
+  """Does a forward and backward pass of a model and returns useful intermediates."""
+  # First loop computes the model outputs, summed loss, and generator outputs.
+  with tape:
+    model_outputs, generator_outputs_list = model_forward_pass(
+        input_model, x_batch, generator_fn=registry_generator_fn
+    )
+
+    # Ignore the original loss function's reduction to get per-example loss.
+    if per_example_loss_fn is None:
+      per_example_loss_fn = _infer_per_example_loss_fn(input_model)
+    losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
+    if losses.shape is None:
+      raise NotImplementedError(
+          "The unreduced (or per-example) loss's shape cannot be `None`"
+      )
+    if len(losses.shape) != 1:
+      raise NotImplementedError(
+          'The unreduced (or per-example) loss needs to have a shape of length '
+          'one, but received an unreduced loss of shape length %s'
+          % len(losses.shape)
+      )
+    if num_microbatches is not None:
+      losses = tf.reduce_mean(
+          common_manip_utils.maybe_add_microbatch_axis(
+              losses, num_microbatches
+          ),
+          axis=1,
+      )
+    summed_loss = tf.reduce_sum(losses)
+  # Unwrap the generator outputs so that the next loop avoids duplicating
+  # backprop ops.
+  filtered_outputs = [t for t in generator_outputs_list if t is not None]
+
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
+  layer_vars = collections.defaultdict(list)
+  for registry_fn_output in filtered_outputs:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
+    ):
+      layer_vars[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_vars
+      )
+
+  layer_grad_vars = tape.gradient(
+      summed_loss,
+      layer_vars,
+      unconnected_gradients=tf.UnconnectedGradients.ZERO,
+  )
+  if not layer_grad_vars:
+    raise ValueError('The gradient list cannot be empty.')
+
+  return layer_grad_vars, filtered_outputs
+
+
 def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: type_aliases.PackedTensors,

tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py

@@ -19,6 +19,8 @@ import tensorflow as tf

 # Tensorflow aliases.
+Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
+
 PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
 InputTensors = PackedTensors

tensorflow_privacy/privacy/keras_models/dp_keras_model.py

@@ -274,14 +274,36 @@ def make_dp_model_class(cls):
       # trick, and uses these norms to clip the per-example gradients.
       # NOTE: Reshaping of the input according to the effective number of
      # microbatches is done here.
+      tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
+      registry_generator_fn = (
+          gradient_clipping_utils.get_registry_generator_fn(
+              tape=tape,
+              layer_registry=self._layer_registry,
+              num_microbatches=num_microbatches,
+          )
+      )
+      layer_grad_vars, registry_fn_outputs_list = (
+          gradient_clipping_utils.model_forward_backward_pass(
+              tape=tape,
+              input_model=self,
+              x_batch=x,
+              y_batch=y,
+              registry_generator_fn=registry_generator_fn,
+              weight_batch=weights,
+              num_microbatches=num_microbatches,
+              trainable_vars=self.trainable_variables,
+          )
+      )
+
       clipped_grads, y_pred, clipping_loss = (
           clip_grads.compute_clipped_gradients_and_outputs(
               input_model=self,
+              registry_fn_outputs_list=registry_fn_outputs_list,
+              layer_grad_vars=layer_grad_vars,
               x_batch=x,
               y_batch=y,
               weight_batch=weights,
               l2_norm_clip=self._l2_norm_clip,
-              layer_registry=self._layer_registry,
               num_microbatches=self._num_microbatches,
               clipping_loss=self._clipping_loss,
           )