Improve documentation and logging of fast gradient clipping modules and callers.
PiperOrigin-RevId: 513283486
This commit is contained in:
parent
d7cd3f8af1
commit
7436930c64
5 changed files with 129 additions and 42 deletions
|
@ -6,6 +6,7 @@ py_library(
|
|||
name = "gradient_clipping_utils",
|
||||
srcs = ["gradient_clipping_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [":layer_registry"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
|
|
@ -21,11 +21,18 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
|
|||
`compute_gradient_norms()` function).
|
||||
"""
|
||||
|
||||
from typing import Union, Iterable, Text
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||
|
||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
||||
|
||||
|
||||
def get_registry_generator_fn(tape, layer_registry):
|
||||
def get_registry_generator_fn(
|
||||
tape: tf.GradientTape, layer_registry: lr.LayerRegistry
|
||||
):
|
||||
"""Creates the generator function for `compute_gradient_norms()`."""
|
||||
if layer_registry is None:
|
||||
# Needed for backwards compatibility.
|
||||
|
@ -53,7 +60,12 @@ def get_registry_generator_fn(tape, layer_registry):
|
|||
return registry_generator_fn
|
||||
|
||||
|
||||
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||
def compute_gradient_norms(
|
||||
input_model: tf.keras.Model,
|
||||
x_batch: InputTensor,
|
||||
y_batch: tf.Tensor,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
):
|
||||
"""Computes the per-example loss gradient norms for given data.
|
||||
|
||||
Applies a variant of the approach given in
|
||||
|
@ -62,7 +74,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
|||
Args:
|
||||
input_model: The `tf.keras.Model` from which to obtain the layers from. The
|
||||
loss of the model *must* be a scalar loss.
|
||||
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
|
||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
|
@ -106,7 +118,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
|||
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
|
||||
|
||||
|
||||
def compute_clip_weights(l2_norm_clip, gradient_norms):
|
||||
def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
|
||||
"""Computes the per-example loss/clip weights for clipping.
|
||||
|
||||
When the sum of the per-example losses is replaced a weighted sum, where
|
||||
|
@ -132,7 +144,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
|
|||
|
||||
|
||||
def compute_pred_and_clipped_gradients(
|
||||
input_model, x_batch, y_batch, l2_norm_clip, layer_registry
|
||||
input_model: tf.keras.Model,
|
||||
x_batch: InputTensor,
|
||||
y_batch: tf.Tensor,
|
||||
l2_norm_clip: float,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
):
|
||||
"""Computes the per-example predictions and per-example clipped loss gradient.
|
||||
|
||||
|
@ -147,7 +163,7 @@ def compute_pred_and_clipped_gradients(
|
|||
|
||||
Args:
|
||||
input_model: The `tf.keras.Model` from which to obtain the layers from.
|
||||
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
|
||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
from typing import Callable, Any, List, Union
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -20,23 +22,35 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
|||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Type aliases
|
||||
# ==============================================================================
|
||||
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
|
||||
|
||||
ModelGenerator = Callable[
|
||||
[LayerGenerator, Union[int, List[int]], int], tf.keras.Model
|
||||
]
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Helper functions and classes.
|
||||
# ==============================================================================
|
||||
class DoubleDense(tf.keras.layers.Layer):
|
||||
"""Generates two dense layers nested together."""
|
||||
|
||||
def __init__(self, units):
|
||||
def __init__(self, units: int):
|
||||
super().__init__()
|
||||
self.dense1 = tf.keras.layers.Dense(units)
|
||||
self.dense2 = tf.keras.layers.Dense(1)
|
||||
|
||||
def call(self, inputs):
|
||||
def call(self, inputs: Any):
|
||||
x = self.dense1(inputs)
|
||||
return self.dense2(x)
|
||||
|
||||
|
||||
def double_dense_layer_computation(layer_instance, inputs, tape):
|
||||
def double_dense_layer_computation(
|
||||
layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape
|
||||
):
|
||||
"""Layer registry function for the custom `DoubleDense` layer class."""
|
||||
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
|
||||
layer_instance.dense1, inputs, tape
|
||||
|
@ -53,7 +67,9 @@ def double_dense_layer_computation(layer_instance, inputs, tape):
|
|||
return [vars1, vars2], outputs, sqr_norm_fn
|
||||
|
||||
|
||||
def compute_true_gradient_norms(input_model, x_batch, y_batch):
|
||||
def compute_true_gradient_norms(
|
||||
input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor
|
||||
):
|
||||
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
||||
loss_config = input_model.loss.get_config()
|
||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||
|
@ -73,14 +89,14 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch):
|
|||
|
||||
|
||||
def get_computed_and_true_norms(
|
||||
model_generator,
|
||||
layer_generator,
|
||||
input_dims,
|
||||
output_dim,
|
||||
is_eager,
|
||||
x_input,
|
||||
rng_seed=777,
|
||||
registry=None,
|
||||
model_generator: ModelGenerator,
|
||||
layer_generator: LayerGenerator,
|
||||
input_dims: Union[int, List[int]],
|
||||
output_dim: int,
|
||||
is_eager: bool,
|
||||
x_input: tf.Tensor,
|
||||
rng_seed: int = 777,
|
||||
registry: layer_registry.LayerRegistry = None,
|
||||
):
|
||||
"""Obtains the true and computed gradient norms for a model and batch input.
|
||||
|
||||
|
@ -238,7 +254,7 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
|
|||
# ==============================================================================
|
||||
# Factory functions.
|
||||
# ==============================================================================
|
||||
def get_nd_test_tensors(n):
|
||||
def get_nd_test_tensors(n: int):
|
||||
"""Returns a list of candidate tests for a given dimension n."""
|
||||
return [
|
||||
tf.zeros((n,), dtype=tf.float64),
|
||||
|
@ -246,7 +262,7 @@ def get_nd_test_tensors(n):
|
|||
]
|
||||
|
||||
|
||||
def get_nd_test_batches(n):
|
||||
def get_nd_test_batches(n: int):
|
||||
"""Returns a list of candidate input batches of dimension n."""
|
||||
result = []
|
||||
tensors = get_nd_test_tensors(n)
|
||||
|
|
|
@ -13,11 +13,19 @@
|
|||
# limitations under the License.
|
||||
"""Utility functions that help in the computation of per-example gradient norms."""
|
||||
|
||||
from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||
|
||||
def has_internal_compute_graph(input_object):
|
||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
||||
|
||||
GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]]
|
||||
|
||||
|
||||
def has_internal_compute_graph(input_object: Any):
|
||||
"""Checks if input is a TF model and has a TF internal compute graph."""
|
||||
return (
|
||||
isinstance(input_object, tf.keras.Model)
|
||||
|
@ -28,7 +36,9 @@ def has_internal_compute_graph(input_object):
|
|||
)
|
||||
|
||||
|
||||
def _get_internal_layers(input_layer):
|
||||
def _get_internal_layers(
|
||||
input_layer: tf.keras.layers.Layer,
|
||||
) -> list[tf.keras.layers.Layer]:
|
||||
"""Returns a list of layers that are nested within a given layer."""
|
||||
internal_layers = []
|
||||
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
||||
|
@ -39,7 +49,11 @@ def _get_internal_layers(input_layer):
|
|||
return internal_layers
|
||||
|
||||
|
||||
def model_forward_pass(input_model, inputs, generator_fn=None):
|
||||
def model_forward_pass(
|
||||
input_model: tf.keras.Model,
|
||||
inputs: InputTensor,
|
||||
generator_fn: GeneratorFunction = None,
|
||||
) -> Tuple[tf.Tensor, list[Any]]:
|
||||
"""Does a forward pass of a model and returns useful intermediates.
|
||||
|
||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||
|
@ -118,7 +132,9 @@ def model_forward_pass(input_model, inputs, generator_fn=None):
|
|||
return node_layer_outputs, generator_outputs_list
|
||||
|
||||
|
||||
def all_trainable_layers_are_registered(input_model, layer_registry):
|
||||
def all_trainable_layers_are_registered(
|
||||
input_model: tf.keras.Model, layer_registry: lr.LayerRegistry
|
||||
) -> bool:
|
||||
"""Check if an input model's trainable layers are all registered.
|
||||
|
||||
Args:
|
||||
|
@ -140,18 +156,21 @@ def all_trainable_layers_are_registered(input_model, layer_registry):
|
|||
|
||||
|
||||
def add_aggregate_noise(
|
||||
input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier
|
||||
):
|
||||
input_model: tf.keras.Model,
|
||||
x_batch: InputTensor,
|
||||
clipped_grads: list[tf.Tensor],
|
||||
l2_norm_clip: float,
|
||||
noise_multiplier: float,
|
||||
) -> list[tf.Tensor]:
|
||||
"""Adds noise to a collection of clipped gradients.
|
||||
|
||||
The magnitude of the noise depends on the aggregation strategy of the
|
||||
input model's loss function.
|
||||
|
||||
Args:
|
||||
input_model: The Keras model to obtain the layers from.
|
||||
x_batch: A collection of Tensors to be fed into the input layer of the
|
||||
model.
|
||||
clipped_grads: A list of tensors representing the clipped gradients.
|
||||
input_model: The `tf.keras.Model` to obtain the layers from.
|
||||
x_batch: An `InputTensor` to be fed into the input layer of the model.
|
||||
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
|
||||
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
|
||||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||
|
||||
|
@ -187,7 +206,9 @@ def add_aggregate_noise(
|
|||
return tf.nest.map_structure(add_noise, clipped_grads)
|
||||
|
||||
|
||||
def generate_model_outputs_using_core_keras_layers(input_model):
|
||||
def generate_model_outputs_using_core_keras_layers(
|
||||
input_model: tf.keras.Model,
|
||||
) -> tf.Tensor:
|
||||
"""Returns the model outputs generated by only core Keras layers."""
|
||||
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
||||
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
||||
|
|
|
@ -40,9 +40,24 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
|
|||
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
|
||||
"""
|
||||
|
||||
from typing import Callable, Type, Any, Union, Iterable, Text
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Type aliases
|
||||
# ==============================================================================
|
||||
SquareNormFunction = Callable[[Any], tf.Tensor]
|
||||
|
||||
RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction]
|
||||
|
||||
RegistryFunction = Callable[
|
||||
[Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput
|
||||
]
|
||||
|
||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main class
|
||||
# ==============================================================================
|
||||
|
@ -54,15 +69,19 @@ class LayerRegistry:
|
|||
self._layer_class_dict = {}
|
||||
self._registry = {}
|
||||
|
||||
def is_elem(self, layer_instance):
|
||||
def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool:
|
||||
"""Checks if a layer instance's class is in the registry."""
|
||||
return hash(layer_instance.__class__) in self._registry
|
||||
|
||||
def lookup(self, layer_instance):
|
||||
def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction:
|
||||
"""Returns the layer registry function for a given layer instance."""
|
||||
return self._registry[hash(layer_instance.__class__)]
|
||||
|
||||
def insert(self, layer_class, layer_registry_function):
|
||||
def insert(
|
||||
self,
|
||||
layer_class: Type[tf.keras.layers.Layer],
|
||||
layer_registry_function: RegistryFunction,
|
||||
):
|
||||
"""Inserts a layer registry function into the internal dictionaries."""
|
||||
layer_key = hash(layer_class)
|
||||
self._layer_class_dict[layer_key] = layer_class
|
||||
|
@ -72,7 +91,11 @@ class LayerRegistry:
|
|||
# ==============================================================================
|
||||
# Supported Keras layers
|
||||
# ==============================================================================
|
||||
def dense_layer_computation(layer_instance, inputs, tape):
|
||||
def dense_layer_computation(
|
||||
layer_instance: tf.keras.layers.Dense,
|
||||
inputs: tuple[InputTensor],
|
||||
tape: tf.GradientTape,
|
||||
) -> RegistryFunctionOutput:
|
||||
"""Registry function for `tf.keras.layers.Dense`.
|
||||
|
||||
The logic for this computation is based on the following paper:
|
||||
|
@ -83,8 +106,9 @@ def dense_layer_computation(layer_instance, inputs, tape):
|
|||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Dense` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
inputs: A tuple containing a single `InputTensor` which can be passed into
|
||||
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
|
||||
output.
|
||||
tape: A `tf.GradientTape` instance that will be used to watch the output
|
||||
`base_vars`.
|
||||
|
||||
|
@ -100,6 +124,8 @@ def dense_layer_computation(layer_instance, inputs, tape):
|
|||
trainable variables in `layer_instance`. These squared norms should be a 1D
|
||||
`tf.Tensor` of length `batch_size`.
|
||||
"""
|
||||
if len(inputs) != 1:
|
||||
raise ValueError("Only layer inputs of length 1 are permitted.")
|
||||
orig_activation = layer_instance.activation
|
||||
layer_instance.activation = None
|
||||
base_vars = layer_instance(*inputs)
|
||||
|
@ -125,7 +151,11 @@ def dense_layer_computation(layer_instance, inputs, tape):
|
|||
return base_vars, outputs, sqr_norm_fn
|
||||
|
||||
|
||||
def embedding_layer_computation(layer_instance, inputs, tape):
|
||||
def embedding_layer_computation(
|
||||
layer_instance: tf.keras.layers.Embedding,
|
||||
inputs: tuple[InputTensor],
|
||||
tape: tf.GradientTape,
|
||||
) -> RegistryFunctionOutput:
|
||||
"""Registry function for `tf.keras.layers.Embedding`.
|
||||
|
||||
The logic of this computation is based on the `tf.keras.layers.Dense`
|
||||
|
@ -136,8 +166,9 @@ def embedding_layer_computation(layer_instance, inputs, tape):
|
|||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Embedding` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
inputs: A tuple containing a single `InputTensor` which can be passed into
|
||||
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
|
||||
output.
|
||||
tape: A `tf.GradientTape` instance that will be used to watch the output
|
||||
`base_vars`.
|
||||
|
||||
|
@ -153,10 +184,12 @@ def embedding_layer_computation(layer_instance, inputs, tape):
|
|||
trainable variables in `layer_instance`. These squared norms should be a 1D
|
||||
`tf.Tensor` of length `batch_size`.
|
||||
"""
|
||||
if len(inputs) != 1:
|
||||
raise ValueError("Only layer inputs of length 1 are permitted.")
|
||||
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
||||
if layer_instance.sparse:
|
||||
raise NotImplementedError("Sparse output tensors are not supported.")
|
||||
if isinstance(inputs, tf.SparseTensor):
|
||||
if isinstance(inputs[0], tf.SparseTensor):
|
||||
raise NotImplementedError("Sparse input tensors are not supported.")
|
||||
|
||||
# Disable experimental features.
|
||||
|
@ -225,7 +258,7 @@ def embedding_layer_computation(layer_instance, inputs, tape):
|
|||
# ==============================================================================
|
||||
# Main factory methods
|
||||
# ==============================================================================
|
||||
def make_default_layer_registry():
|
||||
def make_default_layer_registry() -> LayerRegistry:
|
||||
registry = LayerRegistry()
|
||||
registry.insert(tf.keras.layers.Dense, dense_layer_computation)
|
||||
registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
|
||||
|
|
Loading…
Reference in a new issue