First implementation of the fast gradient clipping algorithm.

PiperOrigin-RevId: 504668189
A. Unique TensorFlower 2023-01-25 14:50:39 -08:00
parent ee3d349a8d
commit a3b14ae20a
4 changed files with 564 additions and 0 deletions

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD (View file)

@@ -0,0 +1,25 @@
load("@rules_python//python:defs.bzl", "py_library")
package(default_visibility = ["//visibility:public"])
py_library(
name = "gradient_clipping_utils",
srcs = ["gradient_clipping_utils.py"],
srcs_version = "PY3",
)
py_library(
name = "layer_registry_factories",
srcs = ["layer_registry_factories.py"],
srcs_version = "PY3",
)
py_library(
name = "clip_grads",
srcs = ["clip_grads.py"],
srcs_version = "PY3",
deps = [
":gradient_clipping_utils",
":layer_registry_factories",
],
)

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py (View file)

@@ -0,0 +1,185 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computes per-example loss clip weights.
For a given Keras model and batch of inputs, computes the per-example
clip weights so that the gradient of the loss function, weighted by these
weights, is equivalent to the gradient of the original loss function but
with the per-example gradients clipped by some clip weight. Uses a variant
of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
def combine_pre_and_post_sqr_norms(pre_sqr_norm, post_grad, layer_hash):
  """Combines pre- and post-activation quantities for a given variable.

  The logic for combining norms depends on the variable's underlying layer.

  Args:
    pre_sqr_norm: A `tf.Tensor` whose first dimension is the batch dimension.
      Contains squared norms that are related to the pre-activation Tensor.
    post_grad: A `tf.Tensor` whose first dimension is the batch dimension.
      Contains gradients that are related to the post-activation Tensor.
    layer_hash: An `int` hash of the variable's underlying layer class.

  Returns:
    A 1D `tf.Tensor` whose i-th entry is the squared norm of the gradient of
    the i-th per-example loss function with respect to the given variable.
  """
  post_sqr_grads = tf.square(post_grad)
  if layer_hash == hash(tf.keras.layers.Embedding):
    scaled_grads = tf.expand_dims(pre_sqr_norm, axis=-1) * post_sqr_grads
    reduction_axes = tf.range(1, tf.rank(scaled_grads))
    return tf.reduce_sum(scaled_grads, axis=reduction_axes)
  else:
    reduction_axes = tf.range(1, tf.rank(post_sqr_grads))
    post_sqr_norm = tf.reduce_sum(post_sqr_grads, axis=reduction_axes)
    return pre_sqr_norm * post_sqr_norm


def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
  """Computes the per-example loss gradient norms for given data.

  Applies the approach given in https://arxiv.org/pdf/2009.03106.pdf, except
  that the batch matrix multiplication operation in Algorithm 2 is replaced
  with two norm computations.

  Args:
    input_model: The `tf.keras.Model` from which to obtain the layers. The
      loss of the model *must* be a scalar loss.
    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
      first axis must be the batch dimension.
    y_batch: A `tf.Tensor` representing a batch of output labels. The first
      axis must be the batch dimension. The number of examples should match
      the number of examples in `x_batch`.
    layer_registry: A `dict` of layers that support "fast" gradient norm
      computations. The key is the class of the layer and the value is a
      function that returns a `tuple` `(sqr_grad_norms, base_vars, transform)`,
      where `sqr_grad_norms` is related to the squared norms of the layer's
      pre-activation tensor, `base_vars` is the intermediate pre-activation
      `tf.Tensor` watched by the gradient tape, and `transform` maps
      `base_vars` to the layer's output (see `layer_registry_factories.py` for
      examples).

  Returns:
    A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
    per-example loss function.
  """
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
  # First loop computes the norms of the layer inputs, caches these inputs,
  # and computes the summed loss.
  with tape:
    model_outputs, pre_norm_list, var_list, layer_hash_list = (
        gradient_clipping_utils.forward_norm_pass(
            input_model, x_batch, tape, layer_registry
        )
    )
    # Ignore the original loss function's reduction to get per-example loss.
    loss_config = input_model.loss.get_config()
    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
    per_example_loss_fn = input_model.loss.from_config(loss_config)
    losses = per_example_loss_fn(y_batch, model_outputs)
    if tf.rank(tf.squeeze(losses)) > 1:
      raise NotImplementedError('Vector losses are not supported.')
    summed_loss = tf.reduce_sum(losses)
  # Second loop computes the norm of the gradient of the loss with respect to
  # the pre-activation tensors, and multiplies these norms with the results of
  # the first loop.
  full_norm_list = []
  grads = tape.gradient(summed_loss, var_list)
  for i in range(len(var_list)):
    full_norm = combine_pre_and_post_sqr_norms(
        pre_norm_list[i], grads[i], layer_hash_list[i]
    )
    full_norm_list.append(full_norm)
  del tape
  # Post-processing for compatibility with non-eager mode.
  full_norm_tsr = tf.stack(full_norm_list, axis=1)
  return tf.sqrt(tf.reduce_sum(full_norm_tsr, axis=1))


def compute_clip_weights(l2_norm_clip, gradient_norms):
  """Computes the per-example loss/clip weights for clipping.

  When the sum of the per-example losses is replaced by a weighted sum, where
  the weights are generated by this method, the gradient of each term in the
  weighted sum is clipped by the given clip value.

  Args:
    l2_norm_clip: A `float` indicating the norm to which per-example gradients
      will be clipped. That is, all gradients of the per-example loss
      functions will have norm at most `l2_norm_clip`.
    gradient_norms: A 1D `tf.Tensor` whose i-th entry is the norm of the
      gradient of the loss function for the i-th input.

  Returns:
    A 1D `tf.Tensor` whose i-th entry `C[i]` is `1.0` if the norm of the
    gradient of the i-th per-example loss `G[i]` is less than `l2_norm_clip`,
    and otherwise a number less than `1.0` so that
    `|G[i]| * C[i] == l2_norm_clip`.
  """
  if l2_norm_clip is None:
    return None
  return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)
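

# Hedged worked example (added for illustration; not part of the original
# module): with `l2_norm_clip = 1.0` and `gradient_norms = [0.5, 2.0, 4.0]`,
# the formula above yields weights `[1.0, 0.5, 0.25]`, so the reweighted
# per-example gradients have norms `[0.5, 1.0, 1.0]`, i.e., each norm is
# capped at the clip value:
#
#   C = 1.0 / max(1.0, [0.5, 2.0, 4.0]) = [1.0, 0.5, 0.25]
#   |G| * C = [0.5 * 1.0, 2.0 * 0.5, 4.0 * 0.25] = [0.5, 1.0, 1.0]

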
def compute_pred_and_clipped_gradients(
    input_model, x_batch, y_batch, l2_norm_clip, layer_registry
):
  """Computes the per-example predictions and per-example clipped loss gradient.

  Given a batch of observations `(x_batch, y_batch)`, the main steps of this
  function are: (i) compute the l2-norm of the gradients of the trainable
  variables of `input_model` for each example in the batch; (ii) use the norms
  computed in (i) to obtain "clip_weights" that are used to generate a
  weighted loss function whose gradient for each example has l2-norm at most
  `l2_norm_clip`; (iii) output the clipped gradients in (ii) and the
  `tf.Tensor` generated by `input_model` when it is given `x_batch` as its
  input.

  Args:
    input_model: The `tf.keras.Model` from which to obtain the layers.
    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
      first axis must be the batch dimension.
    y_batch: A `tf.Tensor` representing a batch of output labels. The first
      axis must be the batch dimension. The number of examples should match
      the number of examples in `x_batch`.
    l2_norm_clip: A `float` indicating the norm to which per-example gradients
      will be clipped. That is, all gradients of the per-example loss
      functions will have norm at most `l2_norm_clip`.
    layer_registry: A `dict` of layers that support "fast" gradient norm
      computations. The key is the class of the layer and the value is a
      function that returns a `tuple` `(sqr_grad_norms, base_vars, transform)`
      (see `layer_registry_factories.py` for examples).

  Returns:
    A `tuple` `(y_pred, grad)`. The first element is the prediction generated
    by the model on the input `x_batch`. The second element is the clipped
    gradient of the loss function.
  """
  gradient_norms = compute_gradient_norms(
      input_model, x_batch, y_batch, layer_registry
  )
  loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
  with tf.GradientTape() as tape:
    y_pred = input_model(x_batch, training=True)
    loss_value = input_model.compute_loss(
        x_batch, y_batch, y_pred, loss_weights
    )
  return y_pred, tape.gradient(loss_value, input_model.trainable_variables)
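

# Hedged usage sketch (added for illustration; not part of the original
# commit). It shows one way a DP-SGD style training step could consume
# `compute_pred_and_clipped_gradients`. The model, shapes, and hyperparameters
# below are assumptions made purely for demonstration.
if __name__ == '__main__':
  from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories

  model_inputs = tf.keras.Input(shape=(3,))
  model_outputs = tf.keras.layers.Dense(1)(model_inputs)
  demo_model = tf.keras.Model(inputs=model_inputs, outputs=model_outputs)
  demo_model.compile(loss=tf.keras.losses.MeanSquaredError())

  x = tf.random.normal([4, 3])
  y = tf.random.normal([4, 1])
  registry = layer_registry_factories.make_default_layer_registry()
  y_pred, clipped_grads = compute_pred_and_clipped_gradients(
      demo_model, x, y, l2_norm_clip=1.0, layer_registry=registry
  )
  # Each per-example contribution to `clipped_grads` has l2-norm at most 1.0;
  # a DP training loop would add noise (see `add_aggregate_noise` in
  # gradient_clipping_utils.py) before applying the update.
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  optimizer.apply_gradients(zip(clipped_grads, demo_model.trainable_variables))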

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py (View file)

@@ -0,0 +1,214 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""
from absl import logging
import tensorflow as tf
def has_internal_compute_graph(input_object):
"""Checks if input is a TF model and has a TF internal compute graph."""
return (
isinstance(input_object, tf.keras.Model)
and hasattr(input_object, '_flatten_to_reference_inputs')
and hasattr(input_object, '_tensor_usage_count')
and hasattr(input_object, '_conform_to_reference_input')
and hasattr(input_object, '_nodes_by_depth')
)
def forward_norm_pass(input_model, x_batch, tape, layer_registry):
  """Does a forward pass of a model and returns some useful intermediates.

  NOTE: the graph traversal algorithm is an adaptation of the logic in the
  _run_internal_graph() method in the functional.Functional class. Hence,
  forward_norm_pass should only be invoked if the generated model instance is
  an instance of the functional.Functional class.

  Args:
    input_model: A Keras functional model to compute the quantities for.
    x_batch: A collection of Tensors to be fed into the input layer of the
      model.
    tape: A `tf.GradientTape` instance used to record certain operations.
      Assumes that this function is being called inside of this tape.
    layer_registry: A dictionary of layers that support "fast" gradient norm
      computations. The key is the class of the layer and the value is a
      function that returns a triple (norm_list, var_list, transform). For
      more details, see `layer_registry_factories.py`.

  Returns:
    Four objects (outputs, norm_list, var_list, layer_hash_list). The first is
    the set of outputs generated by the forward pass. The second is either
    `None` or a collection of squared norms for each example that helps in
    computing the gradient norm of an example (if such a "nice" computation
    exists). The third is an ordered list of `tf.Tensor` objects that are
    intended to be differentiated with respect to the summed loss of the
    model. The fourth is a list whose i-th element is the hash of the layer
    class of the i-th element of var_list.
  """
  # TODO: Avoid or remove the references to protected methods of `input_model`.  # pylint: disable=g-bad-todo
  # Prepare the inputs and BFS variables.
  flattened_inputs = input_model._flatten_to_reference_inputs(x_batch)  # pylint: disable=protected-access
  tensor_dict = {}
  tensor_usage_count = input_model._tensor_usage_count  # pylint: disable=protected-access
  for x, y in zip(input_model.inputs, flattened_inputs):
    y = input_model._conform_to_reference_input(y, ref_input=x)  # pylint: disable=protected-access
    x_id = str(id(x))
    tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

  # Main computations.
  nodes_by_depth = input_model._nodes_by_depth  # pylint: disable=protected-access
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  var_list = []
  norm_list = []
  layer_hash_list = []
  node_outputs = None

  # Perform BFS feedforward computations.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      if node.is_input:
        continue  # inputs already exist
      if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
        continue  # node is not computable; try skipping
      args, kwargs = node.map_arguments(tensor_dict)
      if has_internal_compute_graph(node.layer):
        # If this node has an internal computational graph, we can recurse.
        node_outputs, node_norm_list, node_var_list, node_layer_hash_list = (
            forward_norm_pass(node.layer, args, tape, layer_registry)
        )
        var_list += node_var_list
        norm_list += node_norm_list
        layer_hash_list += node_layer_hash_list
      else:
        # Either pass through or record some metadata.
        if not node.layer.trainable_variables:
          node_outputs = node.layer(*args, **kwargs)
        else:
          lyr_hash = hash(node.layer.__class__)
          if lyr_hash not in layer_registry:
            raise NotImplementedError(
                'Layer %s is not in the registry of known layers that can '
                'be used for efficient gradient clipping.'
                % node.layer.__class__.__name__
            )
          lyr = layer_registry[lyr_hash]
          node_norms, node_vars, transform = lyr(node.layer, args)
          tape.watch(node_vars)
          node_outputs = transform(node_vars) if transform else node_vars
          var_list.append(node_vars)
          norm_list.append(node_norms)
          layer_hash_list.append(lyr_hash)
      # Update the current dictionary of inputs for the next node.
      for x_id, y in zip(node.flat_output_ids, tf.nest.flatten(node_outputs)):
        tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

  return node_outputs, norm_list, var_list, layer_hash_list


def get_trainable_hidden_layers(input_model):
  """Obtains the trainable hidden layers of a Keras model.

  Args:
    input_model: The Keras model to obtain the layers from.

  Returns:
    A list of Keras layers where the `tf.keras.layers.Layer` ancestor class
    MUST precede any existing `tf.keras.models.Model` ancestor class.
  """
  hidden_layers = []
  for l in input_model.layers:
    for c in l.__class__.__mro__:
      if c == tf.keras.models.Model:
        hidden_layers += get_trainable_hidden_layers(l)
        break
      elif c == tf.keras.layers.InputLayer:
        break
      elif c == tf.keras.layers.Layer:
        if l.trainable_variables:
          hidden_layers.append(l)
        break
  return hidden_layers


def all_trainable_layers_are_registered(input_model, layer_registry):
  """Checks if an input model's trainable layers are all registered.

  Args:
    input_model: The Keras model from which to obtain the layers.
    layer_registry: A dictionary of layers that support "fast" gradient norm
      computations. The key is the class of the layer and the value is a
      function that returns a triple (sqr_grad_norms, base_vars, transform),
      where sqr_grad_norms is related to the squared norm of the layer's
      input, base_vars is the intermediate pre-activation tensor, and
      transform maps base_vars to the layer output.

  Returns:
    True if all the trainable layers in `input_model` are in `layer_registry`.
    False otherwise.
  """
  hidden_layers = get_trainable_hidden_layers(input_model)
  for l in hidden_layers:
    if hash(l.__class__) not in layer_registry:
      return False
  return True


def add_aggregate_noise(
    input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier
):
  """Adds noise to a collection of clipped gradients.

  The magnitude of the noise depends on the aggregation strategy of the
  input model's loss function.

  Args:
    input_model: The Keras model to obtain the layers from.
    x_batch: A collection of Tensors to be fed into the input layer of the
      model.
    clipped_grads: A list of tensors representing the clipped gradients.
    l2_norm_clip: Clipping norm (max L2 norm of each per-example gradient).
    noise_multiplier: Ratio of the noise standard deviation to the clipping
      norm.

  Returns:
    A list of tensors containing the clipped gradients, but with the right
    amount of Gaussian noise added to them (depending on the reduction
    strategy of the loss function).
  """
  scale = l2_norm_clip
  if input_model.loss.reduction in [
      tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
      tf.keras.losses.Reduction.AUTO,
  ]:
    if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
      logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
    if isinstance(x_batch, tf.Tensor):
      scale /= tf.cast(tf.shape(x_batch)[0], tf.float32)
    elif isinstance(x_batch, dict):
      batch_sizes = [
          tf.cast(tf.shape(v)[0], tf.float32) for v in x_batch.values()
      ]
      scale /= tf.math.reduce_min(batch_sizes)
    else:
      raise NotImplementedError(
          'Unknown container/class %s for input' % x_batch.__class__.__name__
      )

  def add_noise(g):
    return g + tf.random.normal(
        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
    )

  return tf.nest.map_structure(add_noise, clipped_grads)
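

# Hedged illustration (added for this write-up; not part of the original
# module): for a loss reduced with SUM_OVER_BATCH_SIZE, a batch of size
# B = 32, l2_norm_clip = 1.0, and noise_multiplier = 1.1, each gradient entry
# receives Gaussian noise with standard deviation
# noise_multiplier * l2_norm_clip / B = 1.1 / 32, matching the 1/B scaling
# that the mean-reduced loss applies to each clipped per-example gradient.
# For a SUM-reduced loss, the standard deviation would stay at
# noise_multiplier * l2_norm_clip = 1.1.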

tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry_factories.py (View file)

@@ -0,0 +1,140 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates a default layer registry.
Defines "fast" gradient norm layer registry functions for use in the "fast"
gradient clipping algorithm. Specifically, each registry function takes
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
outputs: (a) a `tf.Tensor` `G` of gradient norms, (b) a differentiable
`tf.Tensor` `Z`, and (c) either `None` or a function that maps the object in
(b) to the layer instance's output when using the inputs in (ii). If (c) is
`None`, then (b) contains the layer outputs.
When a layer registry function is defined, it is generally assumed that the
following relation holds for each pair `(g, z)` in `zip(G, Z)`:
`|dL/dw|^2 == |dL/dz|^2 * g^2`
where `L` is any per-example loss function and `w` are the trainable variables
corresponding to `(g, z)`.
For example, this relation holds if the layer instance is tf.keras.layers.Dense,
Z contains the pre-activation tensors, i.e., `z = X * w` for input `X`, and `g`
is the norm of the input corresponding to the given per-example loss (see the
formulae in https://arxiv.org/abs/1510.01799 for more details).
The registry functions are registered in a `dict` (registry) whose key is the
hash of the layer class and whose value is the registry function.
"""
import tensorflow as tf
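
# Hedged worked derivation (added for illustration; not part of the original
# module): for a Dense layer computing z_i = W^T x_i + b on a per-example
# input x_i, the chain rule gives dL/dW = x_i (dL/dz_i)^T, an outer product,
# so
#
#   |dL/dW|_F^2 = |x_i|^2 * |dL/dz_i|^2,
#
# which is the relation above with g_i = |x_i|. This is why
# `dense_layer_computation` below only needs the squared input norms and the
# pre-activation tensor, rather than materializing per-example weight
# gradients.
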
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def dense_layer_computation(layer_instance, inputs):
  """Registry function for `tf.keras.layers.Dense`.

  The logic for this computation is based on the following paper:
    https://arxiv.org/abs/1510.01799

  For the sake of efficiency, we fuse the variables and squared gradient norms
  for the kernel weights and bias vector together.

  Args:
    layer_instance: A `tf.keras.layers.Dense` instance.
    inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
      `layer_instance(inputs)` returns a valid output.

  Returns:
    A `tuple` `(sqr_grad_norms, base_vars, transform)`, where `sqr_grad_norms`
    is a 1D `tf.Tensor` of the squared l2-norms of the input tensors,
    `base_vars` is the intermediate Tensor used in the chain-rule / "fast"
    clipping trick, and `transform` is a function that maps `base_vars` to the
    layer outputs.
  """
  orig_activation = layer_instance.activation
  layer_instance.activation = None
  base_vars = layer_instance(*inputs)
  sqr_inputs = tf.square(*inputs)
  inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
  sqr_grad_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes)
  if layer_instance.use_bias:
    # Adding a bias term is equivalent to a layer with no bias term and which
    # adds an additional variable to the layer input that only takes a
    # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
    # of the squared values of the inputs.
    sqr_grad_norms += tf.cast(1.0, dtype=sqr_grad_norms.dtype)
  layer_instance.activation = orig_activation
  return sqr_grad_norms, base_vars, layer_instance.activation
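

# Hedged illustration (added for this write-up; not part of the original
# module): for a Dense layer with `use_bias=True` and a batch of inputs
#
#   x = [[1.0, 2.0],
#        [3.0, 0.0]],
#
# the returned `sqr_grad_norms` is [1 + 4 + 1, 9 + 0 + 1] = [6.0, 10.0], i.e.,
# the squared input norm of each example plus 1.0 for the implicit constant
# bias input, while `base_vars` holds the pre-activation outputs
# `x @ kernel + bias`.

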
def embedding_layer_computation(layer_instance, inputs):
  """Registry function for `tf.keras.layers.Embedding`.

  The logic of this computation is based on the `tf.keras.layers.Dense`
  computation and the fact that an embedding layer is just a dense layer with
  no activation function and an output vector of the form X*W for input X,
  where the i-th row of W is the i-th embedding vector and the j-th row of X
  is a one-hot vector representing the input of example j.

  Args:
    layer_instance: A `tf.keras.layers.Embedding` instance.
    inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
      `layer_instance(inputs)` returns a valid output.

  Returns:
    A `tuple` `(sqr_grad_norms, base_vars, None)`, where `sqr_grad_norms` is
    a `tf.Tensor` that is related to the squared l2-norms of the input tensors
    and `base_vars` is the intermediate Tensor used in the chain-rule / "fast"
    clipping trick.
  """
  if layer_instance.sparse:
    raise NotImplementedError('Sparse output vectors are not supported.')
  # The logic below is applied to properly handle repeated embedding indices.
  # Specifically, sqr_grad_norms will contain the total counts of each
  # embedding index (see how it is processed in the
  # combine_pre_and_post_sqr_norms() function in clip_grads.py). An example is
  # as follows:
  #
  #   inputs =
  #     [[0 0 0 1 2 2],
  #      [0 2 2 2 1 1]]
  #
  #   counts =
  #     [[3 1 2]
  #      [1 2 3]]
  #
  #   sqr_grad_norms =
  #     [[3 3 3 1 2 2],
  #      [1 3 3 3 2 2]]
  #
  base_vars = layer_instance(*inputs)
  indices = tf.cast(*inputs, tf.int32)
  if isinstance(indices, tf.SparseTensor):
    indices = tf.sparse.to_dense(indices)
  counts = tf.math.bincount(indices, axis=-1)
  sqr_grad_norms = tf.cast(
      tf.gather(counts, indices, batch_dims=1), base_vars.dtype
  )
  return sqr_grad_norms, base_vars, None


# ==============================================================================
# Main factory methods
# ==============================================================================
def make_default_layer_registry():
  """Returns the default registry of layers supported by fast clipping."""
  registry = {}
  registry[hash(tf.keras.layers.Dense)] = dense_layer_computation
  registry[hash(tf.keras.layers.Embedding)] = embedding_layer_computation
  return registry
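

# Hedged usage sketch (added for illustration; not part of the original
# commit). It checks whether a small functional model is fully covered by the
# default registry before relying on the fast clipping path. The model below
# is an assumption made purely for demonstration.
if __name__ == '__main__':
  from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

  demo_inputs = tf.keras.Input(shape=(6,), dtype=tf.int32)
  embedded = tf.keras.layers.Embedding(input_dim=3, output_dim=2)(demo_inputs)
  demo_outputs = tf.keras.layers.Dense(1)(tf.keras.layers.Flatten()(embedded))
  demo_model = tf.keras.Model(inputs=demo_inputs, outputs=demo_outputs)

  registry = make_default_layer_registry()
  # Both Embedding and Dense are registered; Flatten has no trainable
  # variables, so it does not need a registry entry.
  assert gradient_clipping_utils.all_trainable_layers_are_registered(
      demo_model, registry
  )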