forked from 626_privacy/tensorflow_privacy
Sparsity Preserving DP-SGD in TF Privacy
Move get_registry_generator_fn from clip_grads.py to gradient_clipping_utils.py and change return type to dataclass. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 660548431
This commit is contained in:
parent
d3f527e775
commit
8294cec132
2 changed files with 78 additions and 44 deletions
|
@ -32,43 +32,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as
|
|||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||
|
||||
|
||||
def get_registry_generator_fn(
|
||||
tape: tf.GradientTape,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||
):
|
||||
"""Creates the generator function for `compute_gradient_norms()`."""
|
||||
if layer_registry is None:
|
||||
# Needed for backwards compatibility.
|
||||
registry_generator_fn = None
|
||||
else:
|
||||
|
||||
def registry_generator_fn(layer_instance, args, kwargs):
|
||||
if layer_instance.trainable_variables:
|
||||
# Only trainable variables factor into the gradient.
|
||||
if not layer_registry.is_elem(layer_instance):
|
||||
raise NotImplementedError(
|
||||
'Layer %s is not in the registry of known layers that can '
|
||||
'be used for efficient gradient clipping.'
|
||||
% layer_instance.__class__.__name__
|
||||
)
|
||||
registry_fn = layer_registry.lookup(layer_instance)
|
||||
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
||||
layer_instance, args, kwargs, tape, num_microbatches
|
||||
)
|
||||
return layer_outputs, (
|
||||
str(id(layer_instance)),
|
||||
layer_vars,
|
||||
layer_sqr_norm_fn,
|
||||
layer_instance.trainable_weights,
|
||||
)
|
||||
else:
|
||||
# Non-trainable layer.
|
||||
return layer_instance(*args, **kwargs), None
|
||||
|
||||
return registry_generator_fn
|
||||
|
||||
|
||||
def _infer_per_example_loss_fn(model: tf.keras.Model):
|
||||
"""Infer the per-example loss from model config."""
|
||||
|
||||
|
@ -190,7 +153,7 @@ def compute_gradient_norms(
|
|||
are applied prior to clipping.
|
||||
"""
|
||||
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
||||
registry_generator_fn = get_registry_generator_fn(
|
||||
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
|
||||
tape, layer_registry, num_microbatches
|
||||
)
|
||||
# First loop computes the model outputs, summed loss, and generator outputs.
|
||||
|
@ -241,12 +204,17 @@ def compute_gradient_norms(
|
|||
# $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
|
||||
# occurrence. This is an over-estimate of the actual norm. For more details,
|
||||
# see the explanation in go/dp-sgd-shared-weights.
|
||||
for layer_id, v, f, weights_list in filtered_outputs:
|
||||
for registry_fn_output in filtered_outputs:
|
||||
if trainable_vars is None or any(
|
||||
w.ref() in trainable_vars for w in weights_list
|
||||
w.ref() in trainable_vars
|
||||
for w in registry_fn_output.layer_trainable_weights
|
||||
):
|
||||
layer_vars[layer_id].append(v)
|
||||
layer_sqr_norm_fns[layer_id].append(f)
|
||||
layer_vars[registry_fn_output.layer_id].append(
|
||||
registry_fn_output.layer_vars
|
||||
)
|
||||
layer_sqr_norm_fns[registry_fn_output.layer_id].append(
|
||||
registry_fn_output.layer_sqr_norm_fn
|
||||
)
|
||||
# Second loop evaluates the squared L2 norm functions and appends the results.
|
||||
layer_grad_vars = tape.gradient(
|
||||
summed_loss,
|
||||
|
|
|
@ -13,14 +13,23 @@
|
|||
# limitations under the License.
|
||||
"""Utility functions that help in the computation of per-example gradient norms."""
|
||||
|
||||
from collections.abc import Sequence, Set
|
||||
from typing import Any, Optional
|
||||
from collections.abc import Callable, Sequence, Set
|
||||
import dataclasses
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RegistryGeneratorFunctionOutput:
|
||||
layer_id: str
|
||||
layer_vars: Optional[Sequence[tf.Variable]]
|
||||
layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
|
||||
layer_trainable_weights: Optional[Sequence[tf.Variable]]
|
||||
|
||||
|
||||
def has_internal_compute_graph(input_object: Any):
|
||||
"""Checks if input is a TF model and has a TF internal compute graph."""
|
||||
return (
|
||||
|
@ -32,6 +41,63 @@ def has_internal_compute_graph(input_object: Any):
|
|||
)
|
||||
|
||||
|
||||
def get_registry_generator_fn(
|
||||
tape: tf.GradientTape,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
|
||||
"""Creates the generator function for `model_forward_backward_pass()`.
|
||||
|
||||
Args:
|
||||
tape: The `tf.GradientTape` to use for the gradient computation.
|
||||
layer_registry: A `dict` of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
|
||||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||
trainable
|
||||
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
|
||||
microbatches. If not None, indicates that the loss is grouped into
|
||||
num_microbatches (in this case, the batch dimension needs to be a multiple
|
||||
of num_microbatches).
|
||||
|
||||
Returns:
|
||||
A function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
|
||||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||
trainable variables.
|
||||
"""
|
||||
if layer_registry is None:
|
||||
# Needed for backwards compatibility.
|
||||
registry_generator_fn = None
|
||||
else:
|
||||
|
||||
def registry_generator_fn(layer_instance, args, kwargs):
|
||||
if layer_instance.trainable_variables:
|
||||
# Only trainable variables factor into the gradient.
|
||||
if not layer_registry.is_elem(layer_instance):
|
||||
raise NotImplementedError(
|
||||
'Layer %s is not in the registry of known layers that can '
|
||||
'be used for efficient gradient clipping.'
|
||||
% layer_instance.__class__.__name__
|
||||
)
|
||||
registry_fn = layer_registry.lookup(layer_instance)
|
||||
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
||||
layer_instance, args, kwargs, tape, num_microbatches
|
||||
)
|
||||
return layer_outputs, RegistryGeneratorFunctionOutput(
|
||||
layer_id=str(id(layer_instance)),
|
||||
layer_vars=layer_vars,
|
||||
layer_sqr_norm_fn=layer_sqr_norm_fn,
|
||||
layer_trainable_weights=layer_instance.trainable_weights,
|
||||
)
|
||||
else:
|
||||
# Non-trainable layer.
|
||||
return layer_instance(*args, **kwargs), None
|
||||
|
||||
return registry_generator_fn
|
||||
|
||||
|
||||
def model_forward_pass(
|
||||
input_model: tf.keras.Model,
|
||||
inputs: type_aliases.PackedTensors,
|
||||
|
|
Loading…
Reference in a new issue