Improve documentation and logging of fast gradient clipping modules and callers.

PiperOrigin-RevId: 513283486
A. Unique TensorFlower 2023-03-01 10:55:23 -08:00
parent d7cd3f8af1
commit 7436930c64
5 changed files with 129 additions and 42 deletions

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

@@ -6,6 +6,7 @@ py_library(
name = "gradient_clipping_utils",
srcs = ["gradient_clipping_utils.py"],
srcs_version = "PY3",
deps = [":layer_registry"],
)
py_library(

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -21,11 +21,18 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
from typing import Union, Iterable, Text
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
def get_registry_generator_fn(tape, layer_registry):
def get_registry_generator_fn(
tape: tf.GradientTape, layer_registry: lr.LayerRegistry
):
"""Creates the generator function for `compute_gradient_norms()`."""
if layer_registry is None:
# Needed for backwards compatibility.
@@ -53,7 +60,12 @@ def get_registry_generator_fn(tape, layer_registry):
return registry_generator_fn
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
def compute_gradient_norms(
input_model: tf.keras.Model,
x_batch: InputTensor,
y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry,
):
"""Computes the per-example loss gradient norms for given data.
Applies a variant of the approach given in
@@ -62,7 +74,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
Args:
input_model: The `tf.keras.Model` from which to obtain the layers. The
loss of the model *must* be a scalar loss.
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
@@ -106,7 +118,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
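For orientation (this example is not part of the commit), a minimal sketch of how `compute_gradient_norms` might be invoked, assuming a compiled Keras model whose loss is a scalar reduction; the toy model and data are invented for illustration:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

# Toy model with a single registered Dense layer and a scalar-reduced loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.random.normal((4, 3))  # batch of 4 examples
y_batch = tf.random.normal((4, 2))

# Expected: one L2 gradient norm per example, i.e., a tensor of shape (4,).
per_example_norms = clip_grads.compute_gradient_norms(
    model, x_batch, y_batch, lr.make_default_layer_registry()
)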
def compute_clip_weights(l2_norm_clip, gradient_norms):
def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
"""Computes the per-example loss/clip weights for clipping.
When the sum of the per-example losses is replaced by a weighted sum, where
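The hunk boundary truncates the docstring here. For intuition, and not as part of the diff: if each per-example loss is weighted by `min(1, C / ||g_i||)`, then the gradient of the weighted sum equals the sum of per-example gradients clipped to L2 norm `C`. A hedged sketch of such a weight computation, which may differ in detail from the library's actual `compute_clip_weights`:

import tensorflow as tf

def example_clip_weights(l2_norm_clip: float,
                         gradient_norms: tf.Tensor) -> tf.Tensor:
  # w_i = C / max(||g_i||, C) == min(1, C / ||g_i||), so each weighted
  # per-example gradient has L2 norm w_i * ||g_i|| <= C.
  return l2_norm_clip / tf.math.maximum(gradient_norms, l2_norm_clip)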
@@ -132,7 +144,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
def compute_pred_and_clipped_gradients(
input_model, x_batch, y_batch, l2_norm_clip, layer_registry
input_model: tf.keras.Model,
x_batch: InputTensor,
y_batch: tf.Tensor,
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
):
"""Computes the per-example predictions and per-example clipped loss gradient.
@@ -147,7 +163,7 @@ def compute_pred_and_clipped_gradients(
Args:
input_model: The `tf.keras.Model` from which to obtain the layers.
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
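The diff for this file ends mid-docstring above. To tie the two functions together, a hedged end-to-end sketch, not taken from the commit, that assumes `compute_pred_and_clipped_gradients` returns a `(predictions, clipped_gradients)` pair as its name and docstring suggest:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())
x_batch = tf.random.normal((4, 3))
y_batch = tf.random.normal((4, 1))

# Assumed return convention: per-example predictions plus gradients whose
# per-example contributions are clipped to L2 norm 1.0.
y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
    model, x_batch, y_batch, 1.0, lr.make_default_layer_registry()
)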

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py

@@ -13,6 +13,8 @@
# limitations under the License.
import itertools
from typing import Callable, Any, List, Union
from absl.testing import parameterized
import tensorflow as tf
@@ -20,23 +22,35 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
# ==============================================================================
# Type aliases
# ==============================================================================
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
ModelGenerator = Callable[
[LayerGenerator, Union[int, List[int]], int], tf.keras.Model
]
# ==============================================================================
# Helper functions and classes.
# ==============================================================================
class DoubleDense(tf.keras.layers.Layer):
"""Generates two dense layers nested together."""
def __init__(self, units):
def __init__(self, units: int):
super().__init__()
self.dense1 = tf.keras.layers.Dense(units)
self.dense2 = tf.keras.layers.Dense(1)
def call(self, inputs):
def call(self, inputs: Any):
x = self.dense1(inputs)
return self.dense2(x)
def double_dense_layer_computation(layer_instance, inputs, tape):
def double_dense_layer_computation(
layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape
):
"""Layer registry function for the custom `DoubleDense` layer class."""
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
layer_instance.dense1, inputs, tape
@@ -53,7 +67,9 @@ def double_dense_layer_computation(layer_instance, inputs, tape):
return [vars1, vars2], outputs, sqr_norm_fn
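To show how the custom registry function above plugs in (a sketch, not test code from the commit): register `double_dense_layer_computation` under the `DoubleDense` class so the clipping machinery can handle its nested dense layers.

# Extend the default registry (Dense, Embedding) with the custom layer.
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)
assert registry.is_elem(DoubleDense(units=4))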
def compute_true_gradient_norms(input_model, x_batch, y_batch):
def compute_true_gradient_norms(
input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor
):
"""Computes the real gradient norms for an input `(model, x, y)`."""
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
@@ -73,14 +89,14 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch):
def get_computed_and_true_norms(
model_generator,
layer_generator,
input_dims,
output_dim,
is_eager,
x_input,
rng_seed=777,
registry=None,
model_generator: ModelGenerator,
layer_generator: LayerGenerator,
input_dims: Union[int, List[int]],
output_dim: int,
is_eager: bool,
x_input: tf.Tensor,
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
):
"""Obtains the true and computed gradient norms for a model and batch input.
@@ -238,7 +254,7 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
# ==============================================================================
# Factory functions.
# ==============================================================================
def get_nd_test_tensors(n):
def get_nd_test_tensors(n: int):
"""Returns a list of candidate tests for a given dimension n."""
return [
tf.zeros((n,), dtype=tf.float64),
@@ -246,7 +262,7 @@
]
def get_nd_test_batches(n):
def get_nd_test_batches(n: int):
"""Returns a list of candidate input batches of dimension n."""
result = []
tensors = get_nd_test_tensors(n)

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

@@ -13,11 +13,19 @@
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""
from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
def has_internal_compute_graph(input_object):
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]]
def has_internal_compute_graph(input_object: Any):
"""Checks if input is a TF model and has a TF internal compute graph."""
return (
isinstance(input_object, tf.keras.Model)
@@ -28,7 +36,9 @@ def has_internal_compute_graph(input_object):
)
def _get_internal_layers(input_layer):
def _get_internal_layers(
input_layer: tf.keras.layers.Layer,
) -> list[tf.keras.layers.Layer]:
"""Returns a list of layers that are nested within a given layer."""
internal_layers = []
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
@@ -39,7 +49,11 @@
return internal_layers
def model_forward_pass(input_model, inputs, generator_fn=None):
def model_forward_pass(
input_model: tf.keras.Model,
inputs: InputTensor,
generator_fn: GeneratorFunction = None,
) -> Tuple[tf.Tensor, list[Any]]:
"""Does a forward pass of a model and returns useful intermediates.
NOTE: the graph traversal algorithm is an adaptation of the logic in the
@@ -118,7 +132,9 @@
return node_layer_outputs, generator_outputs_list
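A usage sketch for the traversal above (not from the commit), assuming the generator function receives `(layer, args, kwargs)` and returns a `(layer_output, side_output)` pair, as the `GeneratorFunction` alias suggests:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(3,)),
    tf.keras.layers.Dense(1),
])

def record_layer_names(layer, args, kwargs):
  # Pass inputs through unchanged; collect each layer's name as a side output.
  return layer(*args, **kwargs), layer.name

outputs, layer_names = gradient_clipping_utils.model_forward_pass(
    model, tf.random.normal((2, 3)), generator_fn=record_layer_names
)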
def all_trainable_layers_are_registered(input_model, layer_registry):
def all_trainable_layers_are_registered(
input_model: tf.keras.Model, layer_registry: lr.LayerRegistry
) -> bool:
"""Check if an input model's trainable layers are all registered.
Args:
@@ -140,18 +156,21 @@
def add_aggregate_noise(
input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier
):
input_model: tf.keras.Model,
x_batch: InputTensor,
clipped_grads: list[tf.Tensor],
l2_norm_clip: float,
noise_multiplier: float,
) -> list[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
input_model: The Keras model to obtain the layers from.
x_batch: A collection of Tensors to be fed into the input layer of the
model.
clipped_grads: A list of tensors representing the clipped gradients.
input_model: The `tf.keras.Model` to obtain the layers from.
x_batch: An `InputTensor` to be fed into the input layer of the model.
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
@@ -187,7 +206,9 @@ def add_aggregate_noise(
return tf.nest.map_structure(add_noise, clipped_grads)
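For intuition about the aggregation note above (an illustrative sketch, not the library's implementation): with a sum-reduced loss the Gaussian noise stddev is the usual `l2_norm_clip * noise_multiplier`, and with a mean-reduced loss it is additionally divided by the batch size, matching the scale of the averaged gradients.

import tensorflow as tf

def example_noise(summed_grad: tf.Tensor, l2_norm_clip: float,
                  noise_multiplier: float, batch_size: int,
                  loss_is_mean_reduced: bool) -> tf.Tensor:
  # DP-SGD-style Gaussian noise, rescaled when the loss averages over examples.
  stddev = l2_norm_clip * noise_multiplier
  if loss_is_mean_reduced:
    stddev /= batch_size
  return summed_grad + tf.random.normal(tf.shape(summed_grad), stddev=stddev)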
def generate_model_outputs_using_core_keras_layers(input_model):
def generate_model_outputs_using_core_keras_layers(
input_model: tf.keras.Model,
) -> tf.Tensor:
"""Returns the model outputs generated by only core Keras layers."""
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])

tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py

@@ -40,9 +40,24 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
"""
from typing import Callable, Type, Any, Union, Iterable, Text
import tensorflow as tf
# ==============================================================================
# Type aliases
# ==============================================================================
SquareNormFunction = Callable[[Any], tf.Tensor]
RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction]
RegistryFunction = Callable[
[Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput
]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
# ==============================================================================
# Main class
# ==============================================================================
@@ -54,15 +69,19 @@ class LayerRegistry:
self._layer_class_dict = {}
self._registry = {}
def is_elem(self, layer_instance):
def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool:
"""Checks if a layer instance's class is in the registry."""
return hash(layer_instance.__class__) in self._registry
def lookup(self, layer_instance):
def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction:
"""Returns the layer registry function for a given layer instance."""
return self._registry[hash(layer_instance.__class__)]
def insert(self, layer_class, layer_registry_function):
def insert(
self,
layer_class: Type[tf.keras.layers.Layer],
layer_registry_function: RegistryFunction,
):
"""Inserts a layer registry function into the internal dictionaries."""
layer_key = hash(layer_class)
self._layer_class_dict[layer_key] = layer_class
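A small usage sketch of the registry API annotated above (not part of the commit):

import tensorflow as tf

registry = LayerRegistry()
registry.insert(tf.keras.layers.Dense, dense_layer_computation)

layer = tf.keras.layers.Dense(3)
if registry.is_elem(layer):
  # Resolves to dense_layer_computation via the class hash.
  registry_fn = registry.lookup(layer)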
@@ -72,7 +91,11 @@
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def dense_layer_computation(layer_instance, inputs, tape):
def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
inputs: tuple[InputTensor],
tape: tf.GradientTape,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`.
The logic for this computation is based on the following paper:
@@ -83,8 +106,9 @@ def dense_layer_computation(layer_instance, inputs, tape):
Args:
layer_instance: A `tf.keras.layers.Dense` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
inputs: A tuple containing a single `InputTensor` which can be passed into
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
@@ -100,6 +124,8 @@ def dense_layer_computation(layer_instance, inputs, tape):
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
"""
if len(inputs) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
orig_activation = layer_instance.activation
layer_instance.activation = None
base_vars = layer_instance(*inputs)
@@ -125,7 +151,11 @@
return base_vars, outputs, sqr_norm_fn
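To make the decomposition concrete (an illustrative identity, not code from the commit): for a dense layer `y = xW + b` under the cited approach (https://arxiv.org/abs/1510.01799), the per-example squared gradient norm factorizes into row norms of the inputs and of the pre-activation gradients, so `sqr_norm_fn` never materializes per-example weight gradients:

import tensorflow as tf

def dense_per_example_sqr_norms(x: tf.Tensor, grad_y: tf.Tensor) -> tf.Tensor:
  # For y = x @ W + b, example i contributes the outer product x_i (dL/dy_i)^T
  # to dL/dW, so ||dL/dW_i||_F^2 = ||x_i||^2 * ||dL/dy_i||^2; the bias adds
  # another ||dL/dy_i||^2, giving (||x_i||^2 + 1) * ||dL/dy_i||^2 in total.
  sqr_inputs = tf.reduce_sum(tf.square(x), axis=1) + 1.0
  sqr_grads = tf.reduce_sum(tf.square(grad_y), axis=1)
  return sqr_inputs * sqr_grads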
def embedding_layer_computation(layer_instance, inputs, tape):
def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
inputs: tuple[InputTensor],
tape: tf.GradientTape,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`.
The logic of this computation is based on the `tf.keras.layers.Dense`
@@ -136,8 +166,9 @@ def embedding_layer_computation(layer_instance, inputs, tape):
Args:
layer_instance: A `tf.keras.layers.Embedding` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
inputs: A tuple containing a single `InputTensor` which can be passed into
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
@@ -153,10 +184,12 @@ def embedding_layer_computation(layer_instance, inputs, tape):
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
"""
if len(inputs) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
if hasattr(layer_instance, "sparse"): # for backwards compatibility
if layer_instance.sparse:
raise NotImplementedError("Sparse output tensors are not supported.")
if isinstance(inputs, tf.SparseTensor):
if isinstance(inputs[0], tf.SparseTensor):
raise NotImplementedError("Sparse input tensors are not supported.")
# Disable experimental features.
@@ -225,7 +258,7 @@
# ==============================================================================
# Main factory methods
# ==============================================================================
def make_default_layer_registry():
def make_default_layer_registry() -> LayerRegistry:
registry = LayerRegistry()
registry.insert(tf.keras.layers.Dense, dense_layer_computation)
registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
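Finally, a sketch of the factory in use (not part of the diff), tying the default registry back to the norm computation in clip_grads.py; the assumption that layers without trainable variables (here `Flatten`) pass through without registration is mine:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(10, 4, input_length=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])
model.compile(loss=tf.keras.losses.MeanSquaredError())

per_example_norms = clip_grads.compute_gradient_norms(
    model,
    x_batch=tf.constant([[1, 2], [3, 4]]),
    y_batch=tf.constant([[0.0], [1.0]]),
    layer_registry=layer_registry.make_default_layer_registry(),
)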