forked from 626_privacy/tensorflow_privacy
Sparsity Preserving DP-SGD in TF Privacy
Move get_registry_generator_fn from clip_grads.py to gradient_clipping_utils.py and change its return type to a dataclass. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660548431
This commit is contained in:
parent
d3f527e775
commit
8294cec132
2 changed files with 78 additions and 44 deletions
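In short, the per-layer generator output changes from a positional 4-tuple to a named dataclass. A minimal, self-contained sketch of the before/after shape (the simplified class below mirrors the RegistryGeneratorFunctionOutput added in this change; field types are loosened to Any to keep the sketch dependency-free):

import dataclasses
from typing import Any, Optional, Sequence


@dataclasses.dataclass(frozen=True)
class RegistryGeneratorFunctionOutput:
  layer_id: str
  layer_vars: Optional[Sequence[Any]]
  layer_sqr_norm_fn: Optional[Any]
  layer_trainable_weights: Optional[Sequence[Any]]


# Before: consumers had to unpack the generator output by position.
old_style = ('140031733609232', ['kernel'], None, ['kernel'])
layer_id, layer_vars, layer_sqr_norm_fn, weights = old_style

# After: the same data travels as named, read-only fields.
new_style = RegistryGeneratorFunctionOutput(
    layer_id='140031733609232',
    layer_vars=['kernel'],
    layer_sqr_norm_fn=None,
    layer_trainable_weights=['kernel'],
)
print(new_style.layer_id, new_style.layer_trainable_weights)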
tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -32,43 +32,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


-def get_registry_generator_fn(
-    tape: tf.GradientTape,
-    layer_registry: lr.LayerRegistry,
-    num_microbatches: Optional[type_aliases.BatchSize] = None,
-):
-  """Creates the generator function for `compute_gradient_norms()`."""
-  if layer_registry is None:
-    # Needed for backwards compatibility.
-    registry_generator_fn = None
-  else:
-
-    def registry_generator_fn(layer_instance, args, kwargs):
-      if layer_instance.trainable_variables:
-        # Only trainable variables factor into the gradient.
-        if not layer_registry.is_elem(layer_instance):
-          raise NotImplementedError(
-              'Layer %s is not in the registry of known layers that can '
-              'be used for efficient gradient clipping.'
-              % layer_instance.__class__.__name__
-          )
-        registry_fn = layer_registry.lookup(layer_instance)
-        (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
-            layer_instance, args, kwargs, tape, num_microbatches
-        )
-        return layer_outputs, (
-            str(id(layer_instance)),
-            layer_vars,
-            layer_sqr_norm_fn,
-            layer_instance.trainable_weights,
-        )
-      else:
-        # Non-trainable layer.
-        return layer_instance(*args, **kwargs), None
-
-  return registry_generator_fn
-
-
 def _infer_per_example_loss_fn(model: tf.keras.Model):
   """Infer the per-example loss from model config."""
@@ -190,7 +153,7 @@ def compute_gradient_norms(
       are applied prior to clipping.
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
-  registry_generator_fn = get_registry_generator_fn(
+  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
       tape, layer_registry, num_microbatches
   )
   # First loop computes the model outputs, summed loss, and generator outputs.
@@ -241,12 +204,17 @@ def compute_gradient_norms(
   # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
   # occurrence. This is an over-estimate of the actual norm. For more details,
   # see the explanation in go/dp-sgd-shared-weights.
-  for layer_id, v, f, weights_list in filtered_outputs:
+  for registry_fn_output in filtered_outputs:
     if trainable_vars is None or any(
-        w.ref() in trainable_vars for w in weights_list
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
     ):
-      layer_vars[layer_id].append(v)
-      layer_sqr_norm_fns[layer_id].append(f)
+      layer_vars[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_vars
+      )
+      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_sqr_norm_fn
+      )
   # Second loop evaluates the squared L2 norm functions and appends the results.
   layer_grad_vars = tape.gradient(
       summed_loss,
tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

@@ -13,14 +13,23 @@
 # limitations under the License.
 """Utility functions that help in the computation of per-example gradient norms."""

-from collections.abc import Sequence, Set
-from typing import Any, Optional
+from collections.abc import Callable, Sequence, Set
+import dataclasses
+from typing import Any, Optional, Tuple

 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


+@dataclasses.dataclass(frozen=True)
+class RegistryGeneratorFunctionOutput:
+  layer_id: str
+  layer_vars: Optional[Sequence[tf.Variable]]
+  layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
+  layer_trainable_weights: Optional[Sequence[tf.Variable]]
+
+
 def has_internal_compute_graph(input_object: Any):
   """Checks if input is a TF model and has a TF internal compute graph."""
   return (
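The dataclass is declared with frozen=True, so instances behave as read-only records once the forward pass has produced them. A short, self-contained sketch of that behavior (the _Example class below is illustrative, not part of the library):

import dataclasses


@dataclasses.dataclass(frozen=True)
class _Example:
  layer_id: str


record = _Example(layer_id='42')
try:
  record.layer_id = '43'  # Assignment to a field of a frozen dataclass.
except dataclasses.FrozenInstanceError:
  print('frozen dataclass instances cannot be modified after creation')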
@@ -32,6 +41,63 @@ def has_internal_compute_graph(input_object: Any):
   )


+def get_registry_generator_fn(
+    tape: tf.GradientTape,
+    layer_registry: lr.LayerRegistry,
+    num_microbatches: Optional[type_aliases.BatchSize] = None,
+) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
+  """Creates the generator function for `model_forward_backward_pass()`.
+
+  Args:
+    tape: The `tf.GradientTape` to use for the gradient computation.
+    layer_registry: A `dict` of layers that support "fast" gradient norm
+      computations. The key is the class of the layer and the value is a
+      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
+      `output` is the pre-activation tensor, `sqr_grad_norms` is related to
+      the squared norms of a layer's pre-activation tensor, and `vars` are the
+      relevant trainable variables.
+    num_microbatches: An optional number or scalar `tf.Tensor` for the number
+      of microbatches. If not None, indicates that the loss is grouped into
+      num_microbatches (in this case, the batch dimension needs to be a
+      multiple of num_microbatches).
+
+  Returns:
+    `None` if `layer_registry` is `None` (needed for backwards compatibility).
+    Otherwise, a function that maps a layer instance and its call args and
+    kwargs to a `tuple` `(layer_outputs, RegistryGeneratorFunctionOutput)`;
+    the second element is `None` for layers without trainable variables.
+  """
+  if layer_registry is None:
+    # Needed for backwards compatibility.
+    registry_generator_fn = None
+  else:
+
+    def registry_generator_fn(layer_instance, args, kwargs):
+      if layer_instance.trainable_variables:
+        # Only trainable variables factor into the gradient.
+        if not layer_registry.is_elem(layer_instance):
+          raise NotImplementedError(
+              'Layer %s is not in the registry of known layers that can '
+              'be used for efficient gradient clipping.'
+              % layer_instance.__class__.__name__
+          )
+        registry_fn = layer_registry.lookup(layer_instance)
+        (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
+            layer_instance, args, kwargs, tape, num_microbatches
+        )
+        return layer_outputs, RegistryGeneratorFunctionOutput(
+            layer_id=str(id(layer_instance)),
+            layer_vars=layer_vars,
+            layer_sqr_norm_fn=layer_sqr_norm_fn,
+            layer_trainable_weights=layer_instance.trainable_weights,
+        )
+      else:
+        # Non-trainable layer.
+        return layer_instance(*args, **kwargs), None
+
+  return registry_generator_fn
+
+
 def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: type_aliases.PackedTensors,
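A hedged usage sketch of the relocated helper, assuming the default registry constructor layer_registry.make_default_layer_registry() and the registry-function call pattern shown in the diff above; names not shown in the diff are assumptions and may differ across versions:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

registry = layer_registry.make_default_layer_registry()  # assumed constructor
dense = tf.keras.layers.Dense(3)
dense.build(input_shape=(None, 4))  # build so trainable_variables is non-empty
x = tf.random.normal([2, 4])

with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
  generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape, registry, num_microbatches=None
  )
  # With layer_registry=None the helper returns None (backwards compatibility).
  outputs, record = generator_fn(dense, (x,), {})

# `record` is a RegistryGeneratorFunctionOutput for registered trainable layers.
print(record.layer_id, record.layer_sqr_norm_fn is not None)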