Sparsity Preserving DP-SGD in TF Privacy
Refactor model_forward_backward_pass out of compute_gradients so that other optimizations, such as sparsity-preserving noise, can integrate with it. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 660924829
This commit is contained in:
parent 8294cec132
commit 09c68750d7
6 changed files with 310 additions and 164 deletions
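At a glance, the refactor moves the forward/backward pass into its own entry point so that callers (and future extensions such as sparsity-preserving noise) can reuse its intermediates. Below is a minimal sketch of how the refactored pieces fit together, following the updated call sites in this diff; the wrapper function `clipped_grads_sketch`, its defaults, and the assumption of a compiled Keras model are illustrative, not part of the change.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry


def clipped_grads_sketch(model, x_batch, y_batch, l2_norm_clip=1.0):
  """Illustrative wrapper: model is a compiled tf.keras.Model."""
  registry = layer_registry.make_default_layer_registry()
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape, layer_registry=registry, num_microbatches=None
  )
  # New entry point: one forward/backward pass that returns per-layer
  # gradients plus the registry outputs needed for norm computation.
  layer_grad_vars, registry_fn_outputs_list = (
      gradient_clipping_utils.model_forward_backward_pass(
          tape=tape,
          input_model=model,
          x_batch=x_batch,
          y_batch=y_batch,
          registry_generator_fn=registry_generator_fn,
      )
  )
  # Downstream clipping now consumes those intermediates directly.
  clipped_grads, y_pred, clipping_loss = (
      clip_grads.compute_clipped_gradients_and_outputs(
          model,
          registry_fn_outputs_list,
          layer_grad_vars,
          l2_norm_clip,
          registry,
          x_batch,
          y_batch,
      )
  )
  return clipped_grads, y_pred, clipping_loss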
@@ -43,6 +43,7 @@ py_library(
    srcs = ["gradient_clipping_utils.py"],
    srcs_version = "PY3",
    deps = [
        ":common_manip_utils",
        ":layer_registry",
        ":type_aliases",
    ],
@@ -94,6 +95,7 @@ py_test(
    deps = [
        ":clip_grads",
        ":common_test_utils",
        ":gradient_clipping_utils",
        ":layer_registry",
        ":type_aliases",
    ],
@@ -22,7 +22,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
"""

import collections
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Optional

import tensorflow as tf
@@ -32,73 +32,81 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def _infer_per_example_loss_fn(model: tf.keras.Model):
  """Infer the per-example loss from model config."""

  def _convert(loss_fn):
    loss_config = loss_fn.get_config()
    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
    return loss_fn.from_config(loss_config)

  model_loss = model.loss
  if isinstance(model_loss, tf.keras.losses.Loss):
    return _convert(model_loss)
  elif isinstance(model_loss, dict):
    # Note that we cannot call the public method `.get_compile_config()` because
    # it calls a numpy function, which is not supported inside a `tf.function`
    # wrapped function.
    compile_config = model._compile_config.config  # pylint: disable=protected-access
    if compile_config is None:
      raise ValueError('Model must be compiled for loss function conversion')
    # Does a weighted mean of the configured losses. Note that we cannot build
    # from the config of the compiled loss because (i) it builds a
    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
    # during its construction, (ii) non-unique `tf.Variables` cannot be used
    # inside a `tf.function`, which is usually where this function is used.
    if 'loss_weights' not in compile_config:
      raise ValueError(
          'Models with multiple loss must have corresponding loss weights for'
          ' loss function conversion'
      )
    weights = compile_config['loss_weights']
    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
    num_losses = len(weights)

    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
      loss_values = []
      if model_loss.keys() - y_pred.keys():
        raise ValueError(
            'y_pred must contain the same keys and the model losses, but '
            'got %s and %s' % (y_pred.keys(), model_loss.keys())
        )
      if model_loss.keys() - y_true.keys():
        raise ValueError(
            'y_true must contain the same keys and the model losses, but '
            'got %s and %s' % (y_true.keys(), model_loss.keys())
        )
      if sample_weight is not None:
        if model_loss.keys() - sample_weight.keys():
          raise ValueError(
              'sample_weight must contain the same keys and the model losses,'
              ' but got %s and %s' % (y_true.keys(), model_loss.keys())
          )
      for k in y_true.keys():
        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
        sgl_value = (
            weights[k]
            * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
            / num_losses
        )
        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
      return tf.math.add_n(loss_values)

    return _per_example_loss_fn
  else:
    raise ValueError(
        'Unsupported type for loss function conversion: {}'.format(
            type(model_loss)
        )
    )


def _compute_gradient_norms_internal(
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
    trainable_vars: Optional[Sequence[tf.Variable]] = None,
) -> tf.Tensor:
  """Computes the per-example loss gradient norms for given data.

  Args:
    registry_fn_outputs_list: A sequence of RegistryGeneratorFunctionOutput
      containing information required to compute the gradient norms and
      contribution counts. Output from
      `gradient_clipping_utils.model_forward_backward_pass()`.
    layer_grad_vars: A mapping of layer id to a list of gradients for each
      trainable variable in the layer. Output from
      `gradient_clipping_utils.model_forward_backward_pass()`.
    trainable_vars: The list of variables included in computing the gradient
      norm. When a layer has multiple variables, we include all the variables if
      any of the variables is in the list. If `trainable_vars` is None, all the
      variables are included.

  Returns:
    A scalar vector, whose i-th entry is the norm of the gradient of the i-th
    weighted example loss (when num_microbatches is None) or the norm of the
    gradient of the i-th microbatch loss (defined as a mean over the
    microbatch). Note that when the loss is weighted (`weight_batch` is not
    None), weights are applied prior to clipping.

  Raises:
    ValueError: If `layer_grad_vars` is empty.
    ValueError: If the number of gradients for a layer is not equal to the
      number of squared norm functions for that layer.
  """
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_vars = set([v.ref() for v in trainable_vars])

  layer_sqr_norm_fns = collections.defaultdict(list)
  # The case of shared weights:
  # If a layer is called k times, it will appear k times in filtered_outputs,
  # with the same id, but potentially with different v and f. The code below
  # groups filtered_outputs by layer_id, so we can correctly compute gradient
  # norms. The gradient norm of a layer that occurs k times is computed as
  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
  # occurrence. This is an over-estimate of the actual norm. For more details,
  # see the explanation in go/dp-sgd-shared-weights.
  for registry_fn_output in registry_fn_outputs_list:
    if trainable_vars is None or any(
        w.ref() in trainable_vars
        for w in registry_fn_output.layer_trainable_weights
    ):
      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
          registry_fn_output.layer_sqr_norm_fn
      )
  if not layer_grad_vars:
    raise ValueError('The gradient list cannot be empty.')
  sqr_norm_list = []
  for layer_id in layer_sqr_norm_fns.keys():
    fns = layer_sqr_norm_fns[layer_id]
    grads = layer_grad_vars[layer_id]
    # Number of duplicates for this layer in `filtered_outputs`.
    num_passes = len(fns)
    if len(fns) != len(grads):
      raise ValueError(
          'There must be as many gradients as squared norm functions.'
      )
    # See go/dp-sgd-shared-weights for more details.
    for fn, grad in zip(fns, grads):
      sqr_norm_list.append(num_passes * fn(grad))
  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
  gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
  return gradient_norms


def compute_gradient_norms(
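The shared-weights comment above describes an upper bound rather than the exact norm: a layer called k times contributes sqrt(k * sum_i c_i^2), where c_i is the squared-norm estimate of its i-th occurrence. A tiny self-contained check of that combination step, mirroring the stacking logic in `_compute_gradient_norms_internal` (the toy numbers are illustrative, not taken from this change):

import tensorflow as tf

# Toy per-occurrence squared-norm estimates for one layer that is called
# k = 2 times on a batch of 3 examples (shape: [batch_size] per occurrence).
sqr_norm_estimates = [tf.constant([1.0, 4.0, 9.0]), tf.constant([0.0, 1.0, 1.0])]
num_passes = len(sqr_norm_estimates)  # k

# Mirror of the combination step: scale each occurrence by k, stack along a
# new axis, sum over occurrences, and take the square root to get a
# per-example norm bound.
scaled = [num_passes * s for s in sqr_norm_estimates]
sqr_norm_tsr = tf.stack(scaled, axis=1)  # shape [3, 2]
norm_bounds = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
print(norm_bounds.numpy())  # sqrt(2 * (c_1^2 + c_2^2)) per example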
@@ -110,7 +118,7 @@ def compute_gradient_norms(
    per_example_loss_fn: Optional[type_aliases.LossFn] = None,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
    trainable_vars: Optional[Sequence[tf.Variable]] = None,
):
) -> tf.Tensor:
  """Computes the per-example loss gradient norms for given data.

  Applies a variant of the approach given in
@@ -154,90 +162,27 @@ def compute_gradient_norms(
  """
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape, layer_registry, num_microbatches
      tape=tape,
      layer_registry=layer_registry,
      num_microbatches=num_microbatches,
  )
  # First loop computes the model outputs, summed loss, and generator outputs.
  with tape:
    model_outputs, generator_outputs_list = (
        gradient_clipping_utils.model_forward_pass(
            input_model, x_batch, generator_fn=registry_generator_fn
        )
    )

  # Ignore the original loss function's reduction to get per-example loss.
  if per_example_loss_fn is None:
    per_example_loss_fn = _infer_per_example_loss_fn(input_model)

  losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
  if losses.shape is None:
    raise NotImplementedError(
        "The unreduced (or per-example) loss's shape cannot be `None`"
    )
  if len(losses.shape) != 1:
    raise NotImplementedError(
        'The unreduced (or per-example) loss needs to have a shape of length '
        'one, but received an unreduced loss of shape length %s'
        % len(losses.shape)
    )
  if num_microbatches is not None:
    losses = tf.reduce_mean(
        common_manip_utils.maybe_add_microbatch_axis(
            losses, num_microbatches
        ),
        axis=1,
    )
  summed_loss = tf.reduce_sum(losses)
  # Unwrap the generator outputs so that the next loop avoids duplicating
  # backprop ops.
  filtered_outputs = [t for t in generator_outputs_list if t is not None]
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_vars = set([v.ref() for v in trainable_vars])
  layer_vars = collections.defaultdict(list)
  layer_sqr_norm_fns = collections.defaultdict(list)
  # The case of shared weights:
  # If a layer is called k times, it will appear k times in filtered_outputs,
  # with the same id, but potentially with different v and f. The code below
  # groups filtered_outputs by layer_id, so we can correctly compute gradient
  # norms. The gradient norm of a layer that occurs k times is computed as
  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
  # occurrence. This is an over-estimate of the actual norm. For more details,
  # see the explanation in go/dp-sgd-shared-weights.
  for registry_fn_output in filtered_outputs:
    if trainable_vars is None or any(
        w.ref() in trainable_vars
        for w in registry_fn_output.layer_trainable_weights
    ):
      layer_vars[registry_fn_output.layer_id].append(
          registry_fn_output.layer_vars
      )
      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
          registry_fn_output.layer_sqr_norm_fn
      )
  # Second loop evaluates the squared L2 norm functions and appends the results.
  layer_grad_vars = tape.gradient(
      summed_loss,
      layer_vars,
      unconnected_gradients=tf.UnconnectedGradients.ZERO,
  )
  if not layer_grad_vars:
    raise ValueError('The gradient list cannot be empty.')
  sqr_norm_list = []
  for layer_id in layer_sqr_norm_fns.keys():
    fns = layer_sqr_norm_fns[layer_id]
    grads = layer_grad_vars[layer_id]
    # Number of duplicates for this layer in `filtered_outputs`.
    num_passes = len(fns)
    if len(fns) != len(grads):
      raise ValueError(
          'There must be as many gradients as squared norm functions.'
      )
    # See go/dp-sgd-shared-weights for more details.
    for fn, grad in zip(fns, grads):
      sqr_norm_list.append(num_passes * fn(grad))
  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
  return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
  layer_grad_vars, generator_outputs_list = (
      gradient_clipping_utils.model_forward_backward_pass(
          tape=tape,
          input_model=input_model,
          x_batch=x_batch,
          y_batch=y_batch,
          registry_generator_fn=registry_generator_fn,
          weight_batch=weight_batch,
          per_example_loss_fn=per_example_loss_fn,
          num_microbatches=num_microbatches,
      )
  )
  return _compute_gradient_norms_internal(
      registry_fn_outputs_list=generator_outputs_list,
      layer_grad_vars=layer_grad_vars,
      trainable_vars=trainable_vars,
  )


def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
@@ -267,14 +212,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):


def compute_clipped_gradients_and_outputs(
    input_model: tf.keras.Model,
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
    l2_norm_clip: float,
    layer_registry: lr.LayerRegistry,
    x_batch: type_aliases.InputTensors,
    y_batch: type_aliases.OutputTensors,
    weight_batch: Optional[tf.Tensor] = None,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
    clipping_loss: Optional[type_aliases.LossFn] = None,
) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]:
  """Computes the per-example clipped loss gradient and other useful outputs.

  Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
@@ -287,15 +235,16 @@ def compute_clipped_gradients_and_outputs(

  Args:
    input_model: The `tf.keras.Model` from which to obtain the layers.
    registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
      containing information required to compute the gradient norms and
      contribution counts. Output from
      `gradient_clipping_utils.model_forward_backward_pass()`.
    layer_grad_vars: A mapping of layer id to a list of gradients for each
      trainable variable in the layer. Output from
      `gradient_clipping_utils.model_forward_backward_pass()`.
    l2_norm_clip: A `float` indicating the norm to which per-example gradients
      will be clipped. That is, all gradients of the per-example loss functions
      will have norm at most `l2_norm_clip`.
    layer_registry: A `dict` of layers that support "fast" gradient norm
      computations. The key is the class of the layer and the value is a
      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
      squared norms of a layer's pre-activation tensor, and `vars` are relevant
      trainable weights (see `layer_registry_factories.py` for examples).
    x_batch: An `InputTensor` representing a batch of inputs to the model. The
      first axes of each tensor must be the batch dimension.
    y_batch: An `OutputTensor` representing a batch of output labels. The first
@@ -330,13 +279,9 @@ def compute_clipped_gradients_and_outputs(
    )
  if clipping_loss is None:
    clipping_loss = input_model.compiled_loss
  gradient_norms = compute_gradient_norms(
      input_model,
      layer_registry,
      x_batch,
      y_batch,
      weight_batch,
      num_microbatches=num_microbatches,
  gradient_norms = _compute_gradient_norms_internal(
      registry_fn_outputs_list=registry_fn_outputs_list,
      layer_grad_vars=layer_grad_vars,
      trainable_vars=input_model.trainable_variables,
  )
  clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
@@ -19,6 +19,7 @@ from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -122,6 +123,29 @@ class CustomLayerTest(tf.test.TestCase, parameterized.TestCase):
    self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)


def _run_model_forward_backward_pass(
    model: tf.keras.Model,
    x_batch: type_aliases.InputTensors,
    y_batch: type_aliases.OutputTensors,
):
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape,
      layer_registry=layer_registry.make_default_layer_registry(),
      num_microbatches=None,
  )
  layer_grad_vars, registry_fn_outputs_list = (
      gradient_clipping_utils.model_forward_backward_pass(
          tape=tape,
          input_model=model,
          x_batch=x_batch,
          y_batch=y_batch,
          registry_generator_fn=registry_generator_fn,
      )
  )
  return layer_grad_vars, registry_fn_outputs_list


class ComputeClippedGradsAndOutputsTest(
    tf.test.TestCase, parameterized.TestCase
):
@@ -153,13 +177,17 @@ class ComputeClippedGradsAndOutputsTest(
    y_batch = tf.reshape(
        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
    )
    layer_grad_vars, registry_fn_outputs_list = (
        _run_model_forward_backward_pass(self._model, x_batch, y_batch)
    )
    # Stop early for efficiency.
    if reduction == 'none':
      with self.assertRaises(NotImplementedError):
        clip_grads.compute_clipped_gradients_and_outputs(
            self._model,
            registry_fn_outputs_list,
            layer_grad_vars,
            l2_norm_clip,
            layer_registry.make_default_layer_registry(),
            x_batch,
            y_batch,
        )
@@ -169,10 +197,12 @@ class ComputeClippedGradsAndOutputsTest(
      y_pred = self._model(x_batch)
      loss_value = loss_fn(y_pred, y_batch)
    true_grads = tape.gradient(loss_value, self._model.trainable_variables)

    clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
        self._model,
        registry_fn_outputs_list,
        layer_grad_vars,
        l2_norm_clip,
        layer_registry.make_default_layer_registry(),
        x_batch,
        y_batch,
    )
@@ -13,11 +13,13 @@
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""

import collections
from collections.abc import Callable, Sequence, Set
import dataclasses
from typing import Any, Optional, Tuple

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -98,6 +100,149 @@ def get_registry_generator_fn(
  return registry_generator_fn


def _infer_per_example_loss_fn(model: tf.keras.Model):
  """Infer the per-example loss from model config."""

  def _convert(loss_fn):
    loss_config = loss_fn.get_config()
    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
    return loss_fn.from_config(loss_config)

  model_loss = model.loss
  if isinstance(model_loss, tf.keras.losses.Loss):
    return _convert(model_loss)
  elif isinstance(model_loss, dict):
    # Note that we cannot call the public method `.get_compile_config()` because
    # it calls a numpy function, which is not supported inside a `tf.function`
    # wrapped function.
    compile_config = model._compile_config.config  # pylint: disable=protected-access
    if compile_config is None:
      raise ValueError('Model must be compiled for loss function conversion')
    # Does a weighted mean of the configured losses. Note that we cannot build
    # from the config of the compiled loss because (i) it builds a
    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
    # during its construction, (ii) non-unique `tf.Variables` cannot be used
    # inside a `tf.function`, which is usually where this function is used.
    if 'loss_weights' not in compile_config:
      raise ValueError(
          'Models with multiple loss must have corresponding loss weights for'
          ' loss function conversion'
      )
    weights = compile_config['loss_weights']
    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
    num_losses = len(weights)

    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
      loss_values = []
      if model_loss.keys() - y_pred.keys():
        raise ValueError(
            'y_pred must contain the same keys and the model losses, but '
            'got %s and %s' % (y_pred.keys(), model_loss.keys())
        )
      if model_loss.keys() - y_true.keys():
        raise ValueError(
            'y_true must contain the same keys and the model losses, but '
            'got %s and %s' % (y_true.keys(), model_loss.keys())
        )
      if sample_weight is not None:
        if model_loss.keys() - sample_weight.keys():
          raise ValueError(
              'sample_weight must contain the same keys and the model losses,'
              ' but got %s and %s' % (y_true.keys(), model_loss.keys())
          )
      for k in y_true.keys():
        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
        sgl_value = (
            weights[k]
            * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
            / num_losses
        )
        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
      return tf.math.add_n(loss_values)

    return _per_example_loss_fn
  else:
    raise ValueError(
        'Unsupported type for loss function conversion: {}'.format(
            type(model_loss)
        )
    )


def model_forward_backward_pass(
    tape: tf.GradientTape,
    input_model: tf.keras.Model,
    x_batch: type_aliases.InputTensors,
    y_batch: type_aliases.OutputTensors,
    registry_generator_fn: Optional[
        Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]
    ],
    weight_batch: Optional[tf.Tensor] = None,
    per_example_loss_fn: Optional[type_aliases.LossFn] = None,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
    trainable_vars: Optional[Sequence[tf.Variable]] = None,
) -> tuple[
    dict[str, list[type_aliases.Tensor]], list[RegistryGeneratorFunctionOutput]
]:
  """Does a forward and backward pass of a model and returns useful intermediates."""
  # First loop computes the model outputs, summed loss, and generator outputs.
  with tape:
    model_outputs, generator_outputs_list = model_forward_pass(
        input_model, x_batch, generator_fn=registry_generator_fn
    )

  # Ignore the original loss function's reduction to get per-example loss.
  if per_example_loss_fn is None:
    per_example_loss_fn = _infer_per_example_loss_fn(input_model)

  losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
  if losses.shape is None:
    raise NotImplementedError(
        "The unreduced (or per-example) loss's shape cannot be `None`"
    )
  if len(losses.shape) != 1:
    raise NotImplementedError(
        'The unreduced (or per-example) loss needs to have a shape of length '
        'one, but received an unreduced loss of shape length %s'
        % len(losses.shape)
    )
  if num_microbatches is not None:
    losses = tf.reduce_mean(
        common_manip_utils.maybe_add_microbatch_axis(
            losses, num_microbatches
        ),
        axis=1,
    )
  summed_loss = tf.reduce_sum(losses)
  # Unwrap the generator outputs so that the next loop avoids duplicating
  # backprop ops.
  filtered_outputs = [t for t in generator_outputs_list if t is not None]

  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_vars = set([v.ref() for v in trainable_vars])
  layer_vars = collections.defaultdict(list)
  for registry_fn_output in filtered_outputs:
    if trainable_vars is None or any(
        w.ref() in trainable_vars
        for w in registry_fn_output.layer_trainable_weights
    ):
      layer_vars[registry_fn_output.layer_id].append(
          registry_fn_output.layer_vars
      )

  layer_grad_vars = tape.gradient(
      summed_loss,
      layer_vars,
      unconnected_gradients=tf.UnconnectedGradients.ZERO,
  )
  if not layer_grad_vars:
    raise ValueError('The gradient list cannot be empty.')

  return layer_grad_vars, filtered_outputs


def model_forward_pass(
    input_model: tf.keras.Model,
    inputs: type_aliases.PackedTensors,
@@ -19,6 +19,8 @@ import tensorflow as tf


# Tensorflow aliases.
Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]

PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]

InputTensors = PackedTensors

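The broadened `Tensor` alias above admits composite gradients such as `tf.IndexedSlices`, which is what a backward pass through an embedding-style lookup produces and which the sparsity-preserving approach in the linked post presumably relies on staying sparse. A small illustrative snippet (standalone, not part of this change) showing the sparse gradient type:

import tensorflow as tf

# Illustrative only: gradients of a gather/embedding lookup with respect to
# the table are returned as `tf.IndexedSlices` (row indices plus dense
# values for only the touched rows), not as a dense tensor.
table = tf.Variable(tf.random.normal([100, 8]))
with tf.GradientTape() as tape:
  rows = tf.gather(table, [3, 7, 7])  # look up a few rows
  loss = tf.reduce_sum(rows)
grad = tape.gradient(loss, table)
print(type(grad).__name__)   # IndexedSlices
print(grad.indices.numpy())  # [3 7 7] -- only touched rows carry gradient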
@@ -274,14 +274,36 @@ def make_dp_model_class(cls):
      # trick, and uses these norms to clip the per-example gradients.
      # NOTE: Reshaping of the input according to the effective number of
      # microbatches is done here.
      tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)

      registry_generator_fn = (
          gradient_clipping_utils.get_registry_generator_fn(
              tape=tape,
              layer_registry=self._layer_registry,
              num_microbatches=num_microbatches,
          )
      )
      layer_grad_vars, registry_fn_outputs_list = (
          gradient_clipping_utils.model_forward_backward_pass(
              tape=tape,
              input_model=self,
              x_batch=x,
              y_batch=y,
              registry_generator_fn=registry_generator_fn,
              weight_batch=weights,
              num_microbatches=num_microbatches,
              trainable_vars=self.trainable_variables,
          )
      )
      clipped_grads, y_pred, clipping_loss = (
          clip_grads.compute_clipped_gradients_and_outputs(
              input_model=self,
              registry_fn_outputs_list=registry_fn_outputs_list,
              layer_grad_vars=layer_grad_vars,
              x_batch=x,
              y_batch=y,
              weight_batch=weights,
              l2_norm_clip=self._l2_norm_clip,
              layer_registry=self._layer_registry,
              num_microbatches=self._num_microbatches,
              clipping_loss=self._clipping_loss,
          )
      )