Sparsity Preserving DP-SGD in TF Privacy

Move get_registry_generator_fn from clip_grads.py to gradient_clipping_utils.py and change return type to dataclass.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660548431
This commit is contained in:
A. Unique TensorFlower 2024-08-07 14:53:58 -07:00
parent d3f527e775
commit 8294cec132
2 changed files with 78 additions and 44 deletions

View file

@ -32,43 +32,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def get_registry_generator_fn(
    tape: tf.GradientTape,
    layer_registry: lr.LayerRegistry,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
):
  """Builds the per-layer generator function for `compute_gradient_norms()`.

  Args:
    tape: Gradient tape that is forwarded to every registry function.
    layer_registry: Registry that is queried (through `is_elem` and `lookup`)
      for the function that handles a given layer. May be `None`.
    num_microbatches: Optional microbatch count that is forwarded to every
      registry function.

  Returns:
    `None` when `layer_registry` is `None`. Otherwise a function that takes
    `(layer_instance, args, kwargs)` and returns `(layer_outputs, info)`.
    `info` is `None` for a layer without trainable variables, and the tuple
    `(layer_id, layer_vars, layer_sqr_norm_fn, trainable_weights)` otherwise.
  """
  if layer_registry is None:
    # Needed for backwards compatibility.
    return None

  def registry_generator_fn(layer_instance, args, kwargs):
    if not layer_instance.trainable_variables:
      # Non-trainable layer: run it directly, nothing to record.
      return layer_instance(*args, **kwargs), None

    # Only trainable variables factor into the gradient, so only layers that
    # own some must be known to the registry.
    if not layer_registry.is_elem(layer_instance):
      raise NotImplementedError(
          'Layer %s is not in the registry of known layers that can '
          'be used for efficient gradient clipping.'
          % layer_instance.__class__.__name__
      )

    handler = layer_registry.lookup(layer_instance)
    layer_vars, layer_outputs, layer_sqr_norm_fn = handler(
        layer_instance, args, kwargs, tape, num_microbatches
    )
    generator_info = (
        str(id(layer_instance)),
        layer_vars,
        layer_sqr_norm_fn,
        layer_instance.trainable_weights,
    )
    return layer_outputs, generator_info

  return registry_generator_fn
def _infer_per_example_loss_fn(model: tf.keras.Model): def _infer_per_example_loss_fn(model: tf.keras.Model):
"""Infer the per-example loss from model config.""" """Infer the per-example loss from model config."""
@ -190,7 +153,7 @@ def compute_gradient_norms(
are applied prior to clipping. are applied prior to clipping.
""" """
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn( registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape, layer_registry, num_microbatches tape, layer_registry, num_microbatches
) )
# First loop computes the model outputs, summed loss, and generator outputs. # First loop computes the model outputs, summed loss, and generator outputs.
@ -241,12 +204,17 @@ def compute_gradient_norms(
# $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
# occurrence. This is an over-estimate of the actual norm. For more details, # occurrence. This is an over-estimate of the actual norm. For more details,
# see the explanation in go/dp-sgd-shared-weights. # see the explanation in go/dp-sgd-shared-weights.
for layer_id, v, f, weights_list in filtered_outputs: for registry_fn_output in filtered_outputs:
if trainable_vars is None or any( if trainable_vars is None or any(
w.ref() in trainable_vars for w in weights_list w.ref() in trainable_vars
for w in registry_fn_output.layer_trainable_weights
): ):
layer_vars[layer_id].append(v) layer_vars[registry_fn_output.layer_id].append(
layer_sqr_norm_fns[layer_id].append(f) registry_fn_output.layer_vars
)
layer_sqr_norm_fns[registry_fn_output.layer_id].append(
registry_fn_output.layer_sqr_norm_fn
)
# Second loop evaluates the squared L2 norm functions and appends the results. # Second loop evaluates the squared L2 norm functions and appends the results.
layer_grad_vars = tape.gradient( layer_grad_vars = tape.gradient(
summed_loss, summed_loss,

View file

@ -13,14 +13,23 @@
# limitations under the License. # limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms.""" """Utility functions that help in the computation of per-example gradient norms."""
from collections.abc import Sequence, Set from collections.abc import Callable, Sequence, Set
from typing import Any, Optional import dataclasses
from typing import Any, Optional, Tuple
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@dataclasses.dataclass(frozen=True)
class RegistryGeneratorFunctionOutput:
  """Immutable record emitted by the registry generator function for a layer.

  `get_registry_generator_fn()` builds one instance for every layer that has
  trainable variables.

  Attributes:
    layer_id: String identifier of the layer; the generator function fills it
      with `str(id(layer_instance))`.
    layer_vars: First element of the tuple returned by the layer's registry
      function.
    layer_sqr_norm_fn: Third element of the tuple returned by the layer's
      registry function.
    layer_trainable_weights: The layer's `trainable_weights`.
  """

  layer_id: str
  layer_vars: Optional[Sequence[tf.Variable]]
  layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
  layer_trainable_weights: Optional[Sequence[tf.Variable]]
def has_internal_compute_graph(input_object: Any): def has_internal_compute_graph(input_object: Any):
"""Checks if input is a TF model and has a TF internal compute graph.""" """Checks if input is a TF model and has a TF internal compute graph."""
return ( return (
@ -32,6 +41,63 @@ def has_internal_compute_graph(input_object: Any):
) )
def get_registry_generator_fn(
    tape: tf.GradientTape,
    layer_registry: lr.LayerRegistry,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
) -> Optional[
    Callable[..., Tuple[tf.Tensor, Optional[RegistryGeneratorFunctionOutput]]]
]:
  """Creates the generator function for `model_forward_backward_pass()`.

  Args:
    tape: The `tf.GradientTape` to use for the gradient computation. It is
      forwarded unchanged to every registry function.
    layer_registry: The registry of layers that support "fast" gradient norm
      computations. It is queried with `is_elem()` and `lookup()`; the
      looked-up function is called as `registry_fn(layer_instance, args,
      kwargs, tape, num_microbatches)` and must return a `tuple`
      `(layer_vars, layer_outputs, layer_sqr_norm_fn)`. May be `None`, in
      which case no generator function is created.
    num_microbatches: An optional number or scalar `tf.Tensor` for the number of
      microbatches. If not None, indicates that the loss is grouped into
      num_microbatches (in this case, the batch dimension needs to be a multiple
      of num_microbatches).

  Returns:
    `None` if `layer_registry` is `None`. Otherwise, a function that takes
    `(layer_instance, args, kwargs)` and returns a `tuple`
    `(layer_outputs, generator_output)`. For a layer with trainable variables,
    `layer_outputs` comes from the layer's registry function and
    `generator_output` is a `RegistryGeneratorFunctionOutput`. For a layer
    without trainable variables, `layer_outputs` is
    `layer_instance(*args, **kwargs)` and `generator_output` is `None`.

  Raises:
    NotImplementedError: Raised by the returned function (not by this factory)
      when a layer with trainable variables is not in `layer_registry`.
  """
  if layer_registry is None:
    # Needed for backwards compatibility.
    return None

  def registry_generator_fn(layer_instance, args, kwargs):
    if layer_instance.trainable_variables:
      # Only trainable variables factor into the gradient.
      if not layer_registry.is_elem(layer_instance):
        raise NotImplementedError(
            'Layer %s is not in the registry of known layers that can '
            'be used for efficient gradient clipping.'
            % layer_instance.__class__.__name__
        )
      registry_fn = layer_registry.lookup(layer_instance)
      (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
          layer_instance, args, kwargs, tape, num_microbatches
      )
      return layer_outputs, RegistryGeneratorFunctionOutput(
          layer_id=str(id(layer_instance)),
          layer_vars=layer_vars,
          layer_sqr_norm_fn=layer_sqr_norm_fn,
          layer_trainable_weights=layer_instance.trainable_weights,
      )
    else:
      # Non-trainable layer: nothing to record, so the second element is None.
      return layer_instance(*args, **kwargs), None

  return registry_generator_fn
def model_forward_pass( def model_forward_pass(
input_model: tf.keras.Model, input_model: tf.keras.Model,
inputs: type_aliases.PackedTensors, inputs: type_aliases.PackedTensors,