Improve documentation and logging of fast gradient clipping modules and callers.

PiperOrigin-RevId: 513283486
Author: A. Unique TensorFlower
Date:   2023-03-01 10:55:23 -08:00
Commit: 7436930c64 (parent d7cd3f8af1)

5 changed files with 129 additions and 42 deletions

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

@@ -6,6 +6,7 @@ py_library(
     name = "gradient_clipping_utils",
     srcs = ["gradient_clipping_utils.py"],
     srcs_version = "PY3",
+    deps = [":layer_registry"],
 )
 
 py_library(

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -21,11 +21,18 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """
 
+from typing import Union, Iterable, Text
+
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
 
 
-def get_registry_generator_fn(tape, layer_registry):
+def get_registry_generator_fn(
+    tape: tf.GradientTape, layer_registry: lr.LayerRegistry
+):
   """Creates the generator function for `compute_gradient_norms()`."""
   if layer_registry is None:
     # Needed for backwards compatibility.
@@ -53,7 +60,12 @@ def get_registry_generator_fn(tape, layer_registry):
   return registry_generator_fn
 
 
-def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
+def compute_gradient_norms(
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    layer_registry: lr.LayerRegistry,
+):
   """Computes the per-example loss gradient norms for given data.
 
   Applies a variant of the approach given in
@@ -62,7 +74,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from. The
       loss of the model *must* be a scalar loss.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
@@ -106,7 +118,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
 
-def compute_clip_weights(l2_norm_clip, gradient_norms):
+def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   """Computes the per-example loss/clip weights for clipping.
 
   When the sum of the per-example losses is replaced a weighted sum, where
@@ -132,7 +144,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
 def compute_pred_and_clipped_gradients(
-    input_model, x_batch, y_batch, l2_norm_clip, layer_registry
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    l2_norm_clip: float,
+    layer_registry: lr.LayerRegistry,
 ):
   """Computes the per-example predictions and per-example clipped loss gradient.
@@ -147,7 +163,7 @@ def compute_pred_and_clipped_gradients(
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
      first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
      must be the batch dimension. The number of examples should match the
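For orientation, a call site for the newly annotated entry points might look like the sketch below. The model, batch shapes, and clip value are made up for illustration; only the function names and argument order come from this file, and the return structure of `compute_pred_and_clipped_gradients()` is assumed from its name and docstring.

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

    # Hypothetical functional model with a scalar loss, as the docstring requires.
    inputs = tf.keras.Input(shape=(3,))
    outputs = tf.keras.layers.Dense(1)(tf.keras.layers.Dense(4)(inputs))
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(loss=tf.keras.losses.MeanSquaredError())

    x_batch = tf.random.normal([8, 3])  # first axis is the batch dimension
    y_batch = tf.random.normal([8, 1])

    registry = lr.make_default_layer_registry()
    # One gradient norm per example in the batch (a 1D tensor of length 8).
    norms = clip_grads.compute_gradient_norms(model, x_batch, y_batch, registry)
    # Predictions plus loss gradients clipped to an L2 norm of 1.0
    # (return order assumed from the function name).
    y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
        model, x_batch, y_batch, 1.0, registry
    )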

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import itertools
+from typing import Callable, Any, List, Union
+
 from absl.testing import parameterized
 import tensorflow as tf
@@ -20,23 +22,35 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 
+# ==============================================================================
+# Type aliases
+# ==============================================================================
+LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
+
+ModelGenerator = Callable[
+    [LayerGenerator, Union[int, List[int]], int], tf.keras.Model
+]
+
+
 # ==============================================================================
 # Helper functions and classes.
 # ==============================================================================
 class DoubleDense(tf.keras.layers.Layer):
   """Generates two dense layers nested together."""
 
-  def __init__(self, units):
+  def __init__(self, units: int):
     super().__init__()
     self.dense1 = tf.keras.layers.Dense(units)
     self.dense2 = tf.keras.layers.Dense(1)
 
-  def call(self, inputs):
+  def call(self, inputs: Any):
     x = self.dense1(inputs)
     return self.dense2(x)
 
 
-def double_dense_layer_computation(layer_instance, inputs, tape):
+def double_dense_layer_computation(
+    layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape
+):
   """Layer registry function for the custom `DoubleDense` layer class."""
   vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
       layer_instance.dense1, inputs, tape
@@ -53,7 +67,9 @@ def double_dense_layer_computation(layer_instance, inputs, tape):
   return [vars1, vars2], outputs, sqr_norm_fn
 
 
-def compute_true_gradient_norms(input_model, x_batch, y_batch):
+def compute_true_gradient_norms(
+    input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor
+):
   """Computes the real gradient norms for an input `(model, x, y)`."""
   loss_config = input_model.loss.get_config()
   loss_config['reduction'] = tf.keras.losses.Reduction.NONE
@@ -73,14 +89,14 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch):
 def get_computed_and_true_norms(
-    model_generator,
-    layer_generator,
-    input_dims,
-    output_dim,
-    is_eager,
-    x_input,
-    rng_seed=777,
-    registry=None,
+    model_generator: ModelGenerator,
+    layer_generator: LayerGenerator,
+    input_dims: Union[int, List[int]],
+    output_dim: int,
+    is_eager: bool,
+    x_input: tf.Tensor,
+    rng_seed: int = 777,
+    registry: layer_registry.LayerRegistry = None,
 ):
   """Obtains the true and computed gradient norms for a model and batch input.
@@ -238,7 +254,7 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
 # ==============================================================================
 # Factory functions.
 # ==============================================================================
-def get_nd_test_tensors(n):
+def get_nd_test_tensors(n: int):
   """Returns a list of candidate tests for a given dimension n."""
   return [
       tf.zeros((n,), dtype=tf.float64),
@@ -246,7 +262,7 @@ def get_nd_test_tensors(n):
   ]
 
 
-def get_nd_test_batches(n):
+def get_nd_test_batches(n: int):
   """Returns a list of candidate input batches of dimension n."""
   result = []
   tensors = get_nd_test_tensors(n)
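The `DoubleDense` helper and its registry function above exist to exercise custom-layer support. A minimal sketch of how a test could wire them into a registry (the pairing below is illustrative, not a line from this diff):

    registry = layer_registry.make_default_layer_registry()
    registry.insert(DoubleDense, double_dense_layer_computation)
    assert registry.is_elem(DoubleDense(units=3))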

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

@@ -13,11 +13,19 @@
 # limitations under the License.
 
 """Utility functions that help in the computation of per-example gradient norms."""
 
+from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional
+
 from absl import logging
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
+
+GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]]
 
 
-def has_internal_compute_graph(input_object):
+def has_internal_compute_graph(input_object: Any):
   """Checks if input is a TF model and has a TF internal compute graph."""
   return (
       isinstance(input_object, tf.keras.Model)
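The new `InputTensor` alias is deliberately broad; each of the values below satisfies it (a sketch with made-up shapes and key names):

    single = tf.ones([8, 3])                                      # tf.Tensor
    several = [tf.ones([8, 3]), tf.ones([8, 5])]                  # Iterable[tf.Tensor]
    named = {'tokens': tf.ones([8, 3]), 'mask': tf.ones([8, 3])}  # dict[Text, tf.Tensor]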
@@ -28,7 +36,9 @@ def has_internal_compute_graph(input_object):
   )
 
 
-def _get_internal_layers(input_layer):
+def _get_internal_layers(
+    input_layer: tf.keras.layers.Layer,
+) -> list[tf.keras.layers.Layer]:
   """Returns a list of layers that are nested within a given layer."""
   internal_layers = []
   if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
@@ -39,7 +49,11 @@ def _get_internal_layers(input_layer):
   return internal_layers
 
 
-def model_forward_pass(input_model, inputs, generator_fn=None):
+def model_forward_pass(
+    input_model: tf.keras.Model,
+    inputs: InputTensor,
+    generator_fn: GeneratorFunction = None,
+) -> Tuple[tf.Tensor, list[Any]]:
   """Does a forward pass of a model and returns useful intermediates.
 
   NOTE: the graph traversal algorithm is an adaptation of the logic in the
@@ -118,7 +132,9 @@ def model_forward_pass(input_model, inputs, generator_fn=None):
   return node_layer_outputs, generator_outputs_list
 
 
-def all_trainable_layers_are_registered(input_model, layer_registry):
+def all_trainable_layers_are_registered(
+    input_model: tf.keras.Model, layer_registry: lr.LayerRegistry
+) -> bool:
   """Check if an input model's trainable layers are all registered.
 
   Args:
@@ -140,18 +156,21 @@ def all_trainable_layers_are_registered(input_model, layer_registry):
 def add_aggregate_noise(
-    input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier
-):
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    clipped_grads: list[tf.Tensor],
+    l2_norm_clip: float,
+    noise_multiplier: float,
+) -> list[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.
 
   The magnitude of the noise depends on the aggregation strategy of the
   input model's loss function.
 
   Args:
-    input_model: The Keras model to obtain the layers from.
-    x_batch: A collection of Tensors to be fed into the input layer of the
-      model.
-    clipped_grads: A list of tensors representing the clipped gradients.
+    input_model: The `tf.keras.Model` to obtain the layers from.
+    x_batch: An `InputTensor` to be fed into the input layer of the model.
+    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
     noise_multiplier: Ratio of the standard deviation to the clipping norm.
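For intuition, the two float arguments set the scale of the Gaussian noise: the per-coordinate standard deviation is `l2_norm_clip * noise_multiplier`, reduced by the batch size when the loss aggregation averages over examples. The snippet below is a minimal sketch of that idea under those assumptions, not the body of `add_aggregate_noise()`:

    def _gaussian_noise_sketch(grad, l2_norm_clip, noise_multiplier, batch_size, mean_loss):
      # Assumed scaling: std = clip * multiplier, divided by the batch size when
      # the model's loss reduction is a mean rather than a sum.
      stddev = l2_norm_clip * noise_multiplier
      if mean_loss:
        stddev /= batch_size
      return grad + tf.random.normal(tf.shape(grad), stddev=stddev)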
@@ -187,7 +206,9 @@ def add_aggregate_noise(
   return tf.nest.map_structure(add_noise, clipped_grads)
 
 
-def generate_model_outputs_using_core_keras_layers(input_model):
+def generate_model_outputs_using_core_keras_layers(
+    input_model: tf.keras.Model,
+) -> tf.Tensor:
   """Returns the model outputs generated by only core Keras layers."""
   cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
   cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])

tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py

@@ -40,9 +40,24 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
 Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
 """
 
+from typing import Callable, Type, Any, Union, Iterable, Text
+
 import tensorflow as tf
 
+# ==============================================================================
+# Type aliases
+# ==============================================================================
+SquareNormFunction = Callable[[Any], tf.Tensor]
+
+RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction]
+
+RegistryFunction = Callable[
+    [Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput
+]
+
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
+
 
 # ==============================================================================
 # Main class
 # ==============================================================================
@@ -54,15 +69,19 @@ class LayerRegistry:
     self._layer_class_dict = {}
     self._registry = {}
 
-  def is_elem(self, layer_instance):
+  def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool:
     """Checks if a layer instance's class is in the registry."""
     return hash(layer_instance.__class__) in self._registry
 
-  def lookup(self, layer_instance):
+  def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction:
     """Returns the layer registry function for a given layer instance."""
     return self._registry[hash(layer_instance.__class__)]
 
-  def insert(self, layer_class, layer_registry_function):
+  def insert(
+      self,
+      layer_class: Type[tf.keras.layers.Layer],
+      layer_registry_function: RegistryFunction,
+  ):
     """Inserts a layer registry function into the internal dictionaries."""
     layer_key = hash(layer_class)
     self._layer_class_dict[layer_key] = layer_class
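Taken together, the three annotated methods support the registration pattern below (an illustrative sketch; `dense_layer_computation` is defined later in this file):

    registry = LayerRegistry()
    registry.insert(tf.keras.layers.Dense, dense_layer_computation)

    dense = tf.keras.layers.Dense(2)
    if registry.is_elem(dense):
      registry_fn = registry.lookup(dense)  # returns dense_layer_computation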
@@ -72,7 +91,11 @@ class LayerRegistry:
 # ==============================================================================
 # Supported Keras layers
 # ==============================================================================
-def dense_layer_computation(layer_instance, inputs, tape):
+def dense_layer_computation(
+    layer_instance: tf.keras.layers.Dense,
+    inputs: tuple[InputTensor],
+    tape: tf.GradientTape,
+) -> RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Dense`.
 
   The logic for this computation is based on the following paper:
@@ -83,8 +106,9 @@ def dense_layer_computation(layer_instance, inputs, tape):
   Args:
     layer_instance: A `tf.keras.layers.Dense` instance.
-    inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
-      `layer_instance(inputs)` returns a valid output.
+    inputs: A tuple containing a single `InputTensor` which can be passed into
+      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
+      output.
     tape: A `tf.GradientTape` instance that will be used to watch the output
       `base_vars`.
@@ -100,6 +124,8 @@ def dense_layer_computation(layer_instance, inputs, tape):
     trainable variables in `layer_instance`. These squared norms should be a 1D
     `tf.Tensor` of length `batch_size`.
   """
+  if len(inputs) != 1:
+    raise ValueError("Only layer inputs of length 1 are permitted.")
   orig_activation = layer_instance.activation
   layer_instance.activation = None
   base_vars = layer_instance(*inputs)
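The new length check makes the calling convention explicit: `inputs` is a 1-tuple wrapping the layer input, so `layer_instance(*inputs)` unpacks it. A minimal sketch of a direct call (shapes are made up):

    x = tf.ones([4, 3])
    dense = tf.keras.layers.Dense(2)
    with tf.GradientTape() as tape:
      base_vars, outputs, sqr_norm_fn = dense_layer_computation(dense, (x,), tape)
    # Passing a bare batch tensor instead, e.g. dense_layer_computation(dense, x, tape),
    # would now raise the ValueError above.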
@@ -125,7 +151,11 @@ def dense_layer_computation(layer_instance, inputs, tape):
   return base_vars, outputs, sqr_norm_fn
 
 
-def embedding_layer_computation(layer_instance, inputs, tape):
+def embedding_layer_computation(
+    layer_instance: tf.keras.layers.Embedding,
+    inputs: tuple[InputTensor],
+    tape: tf.GradientTape,
+) -> RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Embedding`.
 
   The logic of this computation is based on the `tf.keras.layers.Dense`
@@ -136,8 +166,9 @@ def embedding_layer_computation(layer_instance, inputs, tape):
   Args:
     layer_instance: A `tf.keras.layers.Embedding` instance.
-    inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
-      `layer_instance(inputs)` returns a valid output.
+    inputs: A tuple containing a single `InputTensor` which can be passed into
+      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
+      output.
     tape: A `tf.GradientTape` instance that will be used to watch the output
       `base_vars`.
@@ -153,10 +184,12 @@ def embedding_layer_computation(layer_instance, inputs, tape):
     trainable variables in `layer_instance`. These squared norms should be a 1D
     `tf.Tensor` of length `batch_size`.
   """
+  if len(inputs) != 1:
+    raise ValueError("Only layer inputs of length 1 are permitted.")
   if hasattr(layer_instance, "sparse"):  # for backwards compatibility
     if layer_instance.sparse:
       raise NotImplementedError("Sparse output tensors are not supported.")
-  if isinstance(inputs, tf.SparseTensor):
+  if isinstance(inputs[0], tf.SparseTensor):
     raise NotImplementedError("Sparse input tensors are not supported.")
 
   # Disable experimental features.
@@ -225,7 +258,7 @@ def embedding_layer_computation(layer_instance, inputs, tape):
 # ==============================================================================
 # Main factory methods
 # ==============================================================================
-def make_default_layer_registry():
+def make_default_layer_registry() -> LayerRegistry:
   registry = LayerRegistry()
   registry.insert(tf.keras.layers.Dense, dense_layer_computation)
   registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)