Generalize the internal API to allow for more general models + layers.

PiperOrigin-RevId: 509518753
A. Unique TensorFlower 2023-02-14 07:10:03 -08:00
parent 6ee988885a
commit 410814ec39
9 changed files with 340 additions and 300 deletions

View file

@@ -9,8 +9,8 @@ py_library(
)
py_library(
name = "layer_registry_factories",
srcs = ["layer_registry_factories.py"],
name = "layer_registry",
srcs = ["layer_registry.py"],
srcs_version = "PY3",
)
@@ -20,7 +20,7 @@ py_library(
srcs_version = "PY3",
deps = [
":gradient_clipping_utils",
":layer_registry_factories",
":layer_registry",
],
)
@@ -31,6 +31,6 @@ py_test(
srcs_version = "PY3",
deps = [
":clip_grads",
":layer_registry_factories",
":layer_registry",
],
)

View file

@@ -25,40 +25,42 @@ import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
def combine_pre_and_post_sqr_norms(pre_sqr_norm, post_grad, layer_hash):
"""Combines pre and post-activation tensors for a given variable.
The logic for combining norms depends on the variable's underlying layer.
Args:
pre_sqr_norm: A `tf.Tensor` whose first dimension is the batch dimension.
Contains squared norms that are related to the pre-activation Tensor.
post_grad: A `tf.Tensor` whose first dimension is the batch dimension.
Contains gradients that are related to the post-activation Tensor.
layer_hash: A `float` that is the hash of the variable's underlying layer
class.
Returns:
A 1D `tf.Tensor` whose i-th entry is the squared norm of the gradient of the
i-th per-example loss function with respect to the given variable.
"""
post_sqr_grads = tf.square(post_grad)
if layer_hash == hash(tf.keras.layers.Embedding):
scaled_grads = tf.expand_dims(pre_sqr_norm, axis=-1) * post_sqr_grads
reduction_axes = tf.range(1, tf.rank(scaled_grads))
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
def get_registry_generator_fn(tape, layer_registry):
"""Creates the generator function for `compute_gradient_norms()`."""
if layer_registry is None:
# Needed for backwards compatibility.
registry_generator_fn = None
else:
reduction_axes = tf.range(1, tf.rank(post_sqr_grads))
post_sqr_norm = tf.reduce_sum(post_sqr_grads, axis=reduction_axes)
return pre_sqr_norm * post_sqr_norm
def registry_generator_fn(layer_instance, args, kwargs):
if layer_instance.trainable_variables:
# Only trainable variables factor into the gradient.
if not layer_registry.is_elem(layer_instance):
raise NotImplementedError(
'Layer %s is not in the registry of known layers that can '
'be used for efficient gradient clipping.'
% layer_instance.__class__.__name__
)
registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, transform, layer_sqr_norm_fn) = registry_fn(
layer_instance, args
)
if tape is not None:
tape.watch(layer_vars)
layer_outputs = transform(layer_vars) if transform else layer_vars
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
else:
# Non-trainable layer.
return layer_instance(*args, **kwargs), None
return registry_generator_fn
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
"""Computes the per-example loss gradient norms for given data.
Applies the approach given in https://arxiv.org/pdf/2009.03106.pdf, except
the batch matrix multiplication operation in Algorithm 2 is replaced with
two norm computations.
Applies a variant of the approach given in
https://arxiv.org/pdf/2009.03106.pdf
Args:
input_model: The `tf.keras.Model` from which to obtain the layers. The
@@ -68,24 +70,22 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
layer_registry: A `dict` of layers that support "fast" gradient norm
computations. The key is the class of the layer and the value is a
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
`output` is the pre-activation tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
per-example loss function.
"""
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
# First loop computes the norms of the layer inputs, caches these inputs,
# and computes the summed loss.
registry_generator_fn = get_registry_generator_fn(tape, layer_registry)
# First loop computes the model outputs, summed loss, and generator outputs.
with tape:
model_outputs, pre_norm_list, var_list, layer_hash_list = (
gradient_clipping_utils.forward_norm_pass(
input_model, x_batch, tape, layer_registry
model_outputs, generator_outputs_list = (
gradient_clipping_utils.model_forward_pass(
input_model, x_batch, generator_fn=registry_generator_fn
)
)
# Ignore the original loss function's reduction to get per-example loss.
@@ -94,20 +94,19 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs)
summed_loss = tf.reduce_sum(losses)
# Second loop computes the norm of the gradient of the loss with respect to
# the pre-activation tensors, and multiplies these norms with the results of
# the first loop.
full_norm_list = []
grads = tape.gradient(summed_loss, var_list)
for i in range(len(var_list)):
full_norm = combine_pre_and_post_sqr_norms(
pre_norm_list[i], grads[i], layer_hash_list[i]
)
full_norm_list.append(full_norm)
# Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops.
filtered_outputs = [t for t in generator_outputs_list if t is not None]
vars_list = [a for (a, b) in filtered_outputs]
sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
# Second loop evaluates the squared L2 norm functions and appends the results.
grads_list = tape.gradient(summed_loss, vars_list)
sqr_norm_list = []
for grads, f in zip(grads_list, sqr_norm_fns_list):
sqr_norm_list.append(f(grads))
del tape
# Post-processing for compatibility with non-eager mode (very annoying).
full_norm_tsr = tf.stack(full_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(full_norm_tsr, axis=1))
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
def compute_clip_weights(l2_norm_clip, gradient_norms):
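A minimal usage sketch (not part of the diff) of the refactored entry point, based on the signatures shown above; the toy model and batch shapes are made up for illustration.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Small functional model; the fast-clipping path traverses its compute graph.
model_inputs = tf.keras.Input(shape=(4,))
model_outputs = tf.keras.layers.Dense(1)(model_inputs)
model = tf.keras.Model(inputs=model_inputs, outputs=model_outputs)
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.random.normal([8, 4])
y_batch = tf.ones([8, 1])
registry = layer_registry.make_default_layer_registry()

# One gradient norm per example (shape [8]).
norms = clip_grads.compute_gradient_norms(
    model, x_batch, y_batch, layer_registry=registry
)
# Per-example clip weights for an L2 clipping threshold of 1.0.
weights = clip_grads.compute_clip_weights(1.0, norms)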

View file

@@ -17,7 +17,7 @@ from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
# ==============================================================================
@@ -81,7 +81,7 @@ def get_computed_and_true_norms(
)
y_pred = model(x_input)
y_batch = tf.ones_like(y_pred)
registry = layer_registry_factories.make_default_layer_registry()
registry = layer_registry.make_default_layer_registry()
computed_norms = clip_grads.compute_gradient_norms(
model, x_input, y_batch, layer_registry=registry
)

View file

@@ -28,8 +28,19 @@ def has_internal_compute_graph(input_object):
)
def forward_norm_pass(input_model, x_batch, tape, layer_registry):
"""Does a forward pass of a model and returns some useful intermediates.
def _get_internal_layers(input_layer):
"""Returns a list of layers that are nested within a given layer."""
internal_layers = []
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
for layer in input_layer.layers:
internal_layers.extend(_get_internal_layers(layer))
else:
internal_layers.append(input_layer)
return internal_layers
def model_forward_pass(input_model, inputs, generator_fn=None):
"""Does a forward pass of a model and returns useful intermediates.
NOTE: the graph traversal algorithm is an adaptation of the logic in the
_run_internal_graph() method in the functional.Functional class. Hence,
@@ -37,44 +48,42 @@ def forward_norm_pass(input_model, x_batch, tape, layer_registry):
instance is an instance of the functional.Functional class.
Args:
input_model: A Keras functional model to compute the quantities for.
x_batch: A collection of Tensors to be fed into the input layer of the
model.
tape: A tf.GradientTape() instance used to record certain operations.
Assumes that this function is being called inside of this tape.
layer_registry: A dictionary of layers that support "fast" gradient norm
computations. The key is the class of the layer and the value is a
function that returns a triple (norm_list, var_list, transform). For more
details, see `layer_registry_factories.py`.
input_model: A `tf.keras.Model` to compute the quantities for.
inputs: Arbitrary input to be fed into the input layer of the model. It is
expected that `input_model(inputs)` returns a valid output.
generator_fn: A function with signature `(tf.keras.layers.Layer, Any, Any)
-> (tf.Tensor, Any)`, where we require `generator_fn(layer_instance, args,
kwargs)[0] == layer_instance(*args, **kwargs)`. If `generator_fn` is `None`,
the default generator is used, for which
`generator_fn(layer_instance, args, kwargs)[1] is None`.
Returns:
Four objects (outputs, norm_list, var_list, layer_hash_list). The first are
the outputs that are generated as a result of a forward pass. The second is
either `None` or a collection of squared norms for each example that helps
in computing the gradient norm of an example (if such a "nice" computation
exists). The third is an ordered list of tf.Tensor() objects that are
intended to be evaluated with respect to the summed loss of the model. The
fourth is a list whose i-th element is the hash of the layer class of the
i-th element of var_list.
A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
`tf.Tensor` that is generated as a result of a forward pass.
`generator_outputs_list` is a `list` whose i-th entry is the output of
`generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th
layer when the compute graph of `input_model` is traversed in BFS order.
"""
# TODO: Avoid or remove the references to protected methods of `input_model`. # pylint: disable=g-bad-todo
# Default generator.
generator_outputs_list = []
if generator_fn is None:
def generator_fn(layer_instance, args, kwargs):
return layer_instance(*args, **kwargs), None
# Prepare the inputs and BFS variables.
flattened_inputs = input_model._flatten_to_reference_inputs(x_batch) # pylint: disable=protected-access
flattened_inputs = input_model._flatten_to_reference_inputs(inputs) # pylint: disable=protected-access
tensor_dict = {}
tensor_usage_count = input_model._tensor_usage_count # pylint: disable=protected-access
for x, y in zip(input_model.inputs, flattened_inputs):
y = input_model._conform_to_reference_input(y, ref_input=x) # pylint: disable=protected-access
x_id = str(id(x))
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
# Main computations.
nodes_by_depth = input_model._nodes_by_depth # pylint: disable=protected-access
depth_keys = list(nodes_by_depth.keys())
depth_keys.sort(reverse=True)
var_list = []
norm_list = []
layer_hash_list = []
node_outputs = None
# Perform BFS feedforward computations.
for depth in depth_keys:
for node in nodes_by_depth[depth]:
@@ -85,62 +94,28 @@ def forward_norm_pass(input_model, x_batch, tape, layer_registry):
args, kwargs = node.map_arguments(tensor_dict)
if has_internal_compute_graph(node.layer):
# If this node has an internal computational graph, we can recurse.
node_outputs, node_norm_list, node_var_list, node_layer_hash_list = (
forward_norm_pass(node.layer, args, tape, layer_registry)
node_layer_outputs, node_generator_outputs = model_forward_pass(
node.layer, args, generator_fn
)
var_list += node_var_list
norm_list += node_norm_list
layer_hash_list += node_layer_hash_list
generator_outputs_list.extend(node_generator_outputs)
else:
# Either pass through or record some metadata.
if not node.layer.trainable_variables:
node_outputs = node.layer(*args, **kwargs)
else:
lyr_hash = hash(node.layer.__class__)
if lyr_hash not in layer_registry:
raise NotImplementedError(
'Layer %s is not in the registry of known layers that can'
'be used for efficient gradient clipping.'
% node.layer.__class__.__name__
)
lyr = layer_registry[lyr_hash]
node_norms, node_vars, transform = lyr(node.layer, args)
tape.watch(node_vars)
node_outputs = transform(node_vars) if transform else node_vars
var_list.append(node_vars)
norm_list.append(node_norms)
layer_hash_list.append(lyr_hash)
# update the current dictionary of inputs for the next node.
for x_id, y in zip(node.flat_output_ids, tf.nest.flatten(node_outputs)):
# Otherwise, we parse the node directly.
node_layers = _get_internal_layers(node.layer)
for layer in node_layers:
node_layer_outputs, layer_generator_outputs = generator_fn(
layer, args, kwargs
)
generator_outputs_list.append(layer_generator_outputs)
args = (node_layer_outputs,)
kwargs = {}
# Update the current dictionary of inputs for the next node.
for x_id, y in zip(
node.flat_output_ids, tf.nest.flatten(node_layer_outputs)
):
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
return node_outputs, norm_list, var_list, layer_hash_list
def get_trainable_hidden_layers(input_model):
"""Obtains the trainable hidden layers of a Keras model.
Args:
input_model: The Keras model to obtain the layers from.
Returns:
A list of Keras layers where the tensorflow.keras.layers.Layer
ancestor class MUST precede any existing tensorflow.keras.models.Model
ancestor class.
"""
hidden_layers = []
for l in input_model.layers:
for c in l.__class__.__mro__:
if c == tf.keras.models.Model:
hidden_layers += get_trainable_hidden_layers(l)
break
elif c == tf.keras.layers.InputLayer:
break
elif c == tf.keras.layers.Layer:
if l.trainable_variables:
hidden_layers.append(l)
break
return hidden_layers
return node_layer_outputs, generator_outputs_list
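A short sketch (assumptions: a functional Keras model; all names outside this diff are illustrative) of the `generator_fn` contract used by `model_forward_pass`: the first return value must equal the layer's normal output, while the second is arbitrary per-layer metadata collected in traversal order.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import (
    gradient_clipping_utils,
)

inputs = tf.keras.Input(shape=(3,))
hidden = tf.keras.layers.Dense(2, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs=inputs, outputs=outputs)


def name_logging_generator_fn(layer_instance, args, kwargs):
  # First output must equal the layer's usual output; second is free-form.
  return layer_instance(*args, **kwargs), layer_instance.name


model_outputs, layer_names = gradient_clipping_utils.model_forward_pass(
    model, tf.random.normal([5, 3]), generator_fn=name_logging_generator_fn
)
# `layer_names` contains the visited layers' names in traversal (BFS) order.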
def all_trainable_layers_are_registered(input_model, layer_registry):
@@ -148,21 +123,19 @@ def all_trainable_layers_are_registered(input_model, layer_registry):
Args:
input_model: The Keras model from which to obtain the layers.
layer_registry: A dictionary of layers that support "fast" gradient norm
computations. The key is the class of the layer and the value is a
function that returns a triple (output, sqr_grad_norms, vars), where
output is the pre-activator tensor, sqr_grad_norms is the square of the
norm of the layer's input, and vars is an ordered list of the trainable
weights.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
Returns:
True if all the trainable layers in `input_model` are in `layer_registry`.
False otherwise.
"""
hidden_layers = get_trainable_hidden_layers(input_model)
for l in hidden_layers:
if hash(l.__class__) not in layer_registry:
return False
for layer in input_model.layers:
for sublayer in _get_internal_layers(layer):
if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables:
return False
return True
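A hedged sketch of how this check might be used as a guard before enabling fast clipping; the model and variable names below are illustrative.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import (
    gradient_clipping_utils,
    layer_registry,
)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
registry = layer_registry.make_default_layer_registry()

# True here because `Dense` is registered; a model containing an unregistered
# trainable layer would yield False and should fall back to standard clipping.
supports_fast_clipping = (
    gradient_clipping_utils.all_trainable_layers_are_registered(
        model, registry
    )
)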
@@ -212,3 +185,19 @@ def add_aggregate_noise(
)
return tf.nest.map_structure(add_noise, clipped_grads)
def generate_model_outputs_using_core_keras_layers(input_model):
"""Returns the model outputs generated by only core Keras layers."""
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
def generator_fn(layer_instance, args, kwargs):
if hash(layer_instance.__class__) in cust_hash_set:
# Using `.call()` does not register the layer in the compute graph of
# a forward pass.
return layer_instance.call(*args, **kwargs), None
else:
return layer_instance(*args, **kwargs), None
return model_forward_pass(input_model, input_model.inputs, generator_fn)[0]

View file

@@ -0,0 +1,195 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the layer registry class and useful factory functions.
Defines "fast" gradient norm layer registry functions for use in the "fast"
gradient clipping algorithm. Specifically, each registry function takes
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
outputs: (a) a differentiable `tf.Tensor` `Z`, (b) either `None` or a function
that maps the object in (a) to the layer instance's output when using the
inputs in (ii), and (c) a function `F` that generates the per-example
squared gradient norms when it is fed an object representing the gradient of
the summed loss with respect to `Z` in (a). If (b) is `None`, then (a) is
expected to contain the layer outputs.
When a layer registry function is defined, it is generally assumed that the
following relation holds:
`|dL/dW|^2 == F(grad_Z)`
where `grad_Z` is the gradient of the summed loss with respect to `Z`.
For example, if the layer instance is tf.keras.layers.Dense, Z contains the
pre-activation tensors, i.e., `z = X * w` for input `X`, and `g` is a tensor
whose i-th entry is the L2 norm of the i-th input vector, then
`F(grad_Z) = g^2 * l2_row_norm(grad_Z)^2`,
where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
"""
import tensorflow as tf
# ==============================================================================
# Main class
# ==============================================================================
class LayerRegistry:
"""Custom container for layer registry functions."""
def __init__(self):
"""Basic initialization of various internal dictionaries."""
self._layer_class_dict = {}
self._registry = {}
def is_elem(self, layer_instance):
"""Checks if a layer instance's class is in the registry."""
return hash(layer_instance.__class__) in self._registry
def lookup(self, layer_instance):
"""Returns the layer registry function for a given layer instance."""
return self._registry[hash(layer_instance.__class__)]
def insert(self, layer_class, layer_registry_function):
"""Inserts a layer registry function into the internal dictionaries."""
layer_key = hash(layer_class)
self._layer_class_dict[layer_key] = layer_class
self._registry[layer_key] = layer_registry_function
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def dense_layer_computation(layer_instance, inputs):
"""Registry function for `tf.keras.layers.Dense`.
The logic for this computation is based on the following paper:
https://arxiv.org/abs/1510.01799
For the sake of efficiency, we fuse the variables and square grad norms
for the kernel weights and bias vector together.
Args:
layer_instance: A `tf.keras.layers.Dense` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
Returns:
A `tuple` `(base_vars, transform, sqr_norm_fn)`. `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick,
`transform` is a function that maps `base_vars` to the layer outputs, and
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
represents the output of the call `tape.gradient(summed_loss, base_vars)`
where `tape` is a `tf.GradientTape` instance that records the dense
layer computation and `summed_loss` is the sum of the per-example losses
of the underlying model. This function then returns the per-example squared
L2 gradient norms of the trainable variables in `layer_instance`. These
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
"""
orig_activation = layer_instance.activation
layer_instance.activation = None
base_vars = layer_instance(*inputs)
layer_instance.activation = orig_activation
def sqr_norm_fn(base_vars_grads):
sqr_inputs = tf.square(*inputs)
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes)
if layer_instance.use_bias:
# Adding a bias term is equivalent to using no bias term and appending an
# extra input feature that is always 1.0. This is thus equivalent to adding
# 1.0 to the sum of the squared values of the inputs.
input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype)
reduction_axes = tf.range(1, tf.rank(base_vars_grads))
base_vars_sqr_norms = tf.reduce_sum(
tf.square(base_vars_grads), axis=reduction_axes
)
return input_sqr_norms * base_vars_sqr_norms
return base_vars, layer_instance.activation, sqr_norm_fn
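The docstring above states the per-example identity that `sqr_norm_fn` relies on; the following is an illustrative numerical sanity check (not part of the commit) for a bias-free `Dense` layer, comparing the fast formula against per-example Jacobians.

import tensorflow as tf

x = tf.random.normal([4, 3])
dense = tf.keras.layers.Dense(2, use_bias=False)
dense.build(x.shape)  # create the kernel outside the tape

with tf.GradientTape(persistent=True) as tape:
  z = dense(x)  # pre-activation output; plays the role of `base_vars`
  per_example_loss = tf.reduce_sum(tf.square(z), axis=1)
  summed_loss = tf.reduce_sum(per_example_loss)

# Fast formula: ||x_i||^2 * ||dL/dz_i||^2 per example.
grad_z = tape.gradient(summed_loss, z)  # analogue of `base_vars_grads`
fast_sqr_norms = tf.reduce_sum(tf.square(x), axis=1) * tf.reduce_sum(
    tf.square(grad_z), axis=1
)

# Direct computation from per-example Jacobians w.r.t. the kernel.
jacobian = tape.jacobian(per_example_loss, dense.kernel)  # shape [4, 3, 2]
direct_sqr_norms = tf.reduce_sum(tf.square(jacobian), axis=[1, 2])
# The two quantities agree up to floating-point error.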
def embedding_layer_computation(layer_instance, inputs):
"""Registry function for `tf.keras.layers.Embedding`.
The logic of this computation is based on the `tf.keras.layers.Dense`
computation and the fact that an embedding layer is just a dense layer
with no activation function and an output vector of the form X*W for input
X, where the i-th row of W is the i-th embedding vector and the j-th row of
X is a one-hot vector representing the input of example j.
Args:
layer_instance: A `tf.keras.layers.Embedding` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
Returns:
A `tuple` `(base_vars, None, sqr_norm_fn)`, where `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
represents the output of the call `tape.gradient(summed_loss, base_vars)`
where `tape` is a `tf.GradientTape` instance that records the embedding
layer computation and `summed_loss` is the sum of the per-example losses
of the underlying model. This function then returns the per-example squared
L2 gradient norms of the trainable variables in `layer_instance`. These
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
"""
if hasattr(layer_instance, "sparse"): # for backwards compatibility
if layer_instance.sparse:
raise NotImplementedError("Sparse output vectors are not supported.")
if len(inputs[0].shape) != 2:
raise NotImplementedError("Only 2D embedding inputs are supported.")
# The logic below is applied to properly handle repeated embedding indices.
# Specifically, `input_counts` contains the total occurrence count of each
# embedding index for its example; these counts scale the squared gradients
# inside `sqr_norm_fn` below. An example is as follows:
#
# inputs =
# [[0 0 0 1 2 2],
# [0 2 2 2 1 1]]
#
# counts =
# [[3 1 2]
# [1 2 3]]
#
# input_counts =
# [[3 3 3 1 2 2],
# [1 3 3 3 2 2]]
#
base_vars = layer_instance(*inputs)
def sqr_norm_fn(base_vars_grads):
indices = tf.cast(*inputs, tf.int32)
if isinstance(indices, tf.SparseTensor):
indices = tf.sparse.to_dense(indices)
counts = tf.math.bincount(indices, axis=-1)
input_counts = tf.expand_dims(
tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
axis=-1,
)
scaled_grads = input_counts * tf.square(base_vars_grads)
reduction_axes = tf.range(1, tf.rank(scaled_grads))
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
return base_vars, None, sqr_norm_fn
# ==============================================================================
# Main factory methods
# ==============================================================================
def make_default_layer_registry():
registry = LayerRegistry()
registry.insert(tf.keras.layers.Dense, dense_layer_computation)
registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
return registry
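A hedged sketch of extending the new registry with a custom layer. `ScaleLayer` and `scale_layer_computation` are hypothetical and only illustrate the `(base_vars, transform, sqr_norm_fn)` contract documented above.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry


class ScaleLayer(tf.keras.layers.Layer):
  """Multiplies its input by a single trainable scalar."""

  def build(self, input_shape):
    self.scale = self.add_weight(
        name='scale', shape=(), initializer='ones', trainable=True
    )

  def call(self, x):
    return self.scale * x


def scale_layer_computation(layer_instance, inputs):
  """Hypothetical registry function for `ScaleLayer`."""
  base_vars = layer_instance(*inputs)  # also the layer output (transform=None)

  def sqr_norm_fn(base_vars_grads):
    # For z = scale * x, dL_i/d(scale) = sum_j grads_ij * x_ij, so the
    # per-example squared gradient norm is the square of that sum.
    x = inputs[0]
    reduction_axes = tf.range(1, tf.rank(base_vars_grads))
    per_example_grads = tf.reduce_sum(
        base_vars_grads * x, axis=reduction_axes
    )
    return tf.square(per_example_grads)

  return base_vars, None, sqr_norm_fn


registry = layer_registry.make_default_layer_registry()
registry.insert(ScaleLayer, scale_layer_computation)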

View file

@@ -1,143 +0,0 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates a default layer registry.
Defines "fast" gradient norm layer registry functions for use in the "fast"
gradient clipping algorithm. Specifically, each registry function takes
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
outputs: (a) a `tf.Tensor` `G` of gradient norms, (b) a differentiable
`tf.Tensor` `Z`, and (c) either `None` or a function that maps the object in
(b) to the layer instance's output when using the inputs in (ii). If (c) is
`None`, then (b) contains the layer outputs.
When a layer registry function is defined, it is generally assumed that the
following relation holds for each pair `(g, z)` in `zip(G, Z)`:
`|dL/dw|^2 == |dL/dz|^2 * g^2`
where `L` is any per-example loss function and `w` are the trainable variables
corresponding to `(g, z)`.
For example, this relation holds if the layer instance is tf.keras.layers.Dense,
Z contains the pre-activation tensors, i.e., `z = X * w` for input `X`, and `g`
is the norm of the input corresponding to the given per-example loss (see the
formulae in https://arxiv.org/abs/1510.01799 for more details).
The registry functions are registered in a `dict` (registry) whose key is the
hash of the layer class and whose value is the registry function.
"""
import tensorflow as tf
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def dense_layer_computation(layer_instance, inputs):
"""Registry function for `tf.keras.layers.Dense`.
The logic for this computation is based on the following paper:
https://arxiv.org/abs/1510.01799
For the sake of efficiency, we fuse the variables and square grad norms
for the kernel weights and bias vector together.
Args:
layer_instance: A `tf.keras.layers.Dense` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
Returns:
A `tuple` `(sqr_grad_norms, base_vars, transform)`, where `sqr_grad_norms` is a
1D `tf.Tensor` of the squared l2-norms of the input tensors, `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
`transform` is a function that maps `base_vars` to the layer outputs.
"""
orig_activation = layer_instance.activation
layer_instance.activation = None
base_vars = layer_instance(*inputs)
sqr_inputs = tf.square(*inputs)
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
sqr_grad_norms = tf.reduce_sum(tf.square(*inputs), axis=inputs_reduction_axes)
if layer_instance.use_bias:
# Adding a bias term is equivalent to a layer with no bias term and which
# adds an additional variable to the layer input that only takes a constant
# value of 1.0. This is thus equivalent to adding 1.0 to the sum of the
# squared values of the inputs.
sqr_grad_norms += tf.cast(1.0, dtype=sqr_grad_norms.dtype)
layer_instance.activation = orig_activation
return sqr_grad_norms, base_vars, layer_instance.activation
def embedding_layer_computation(layer_instance, inputs):
"""Registry function for `tf.keras.layers.Embedding`.
The logic of this computation is based on the `tf.keras.layers.Dense`
computation and the fact that an embedding layer is just a dense layer
with no activation function and an output vector of the form X*W for input
X, where the i-th row of W is the i-th embedding vector and the j-th row of
X is a one-hot vector representing the input of example j.
Args:
layer_instance: A `tf.keras.layers.Embedding` instance.
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
`layer_instance(inputs)` returns a valid output.
Returns:
A `tuple` `(sqr_grad_norms, base_vars, None)`, where `sqr_grad_norms` is
a `tf.Tensor` that is related to the squared l2-norms of the input tensors
and `base_vars` is the intermediate Tensor used in the chain-rule / "fast"
clipping trick.
"""
if hasattr(layer_instance, "sparse"): # for backwards compatibility
if layer_instance.sparse:
raise NotImplementedError("Sparse output vectors are not supported.")
if tf.rank(*inputs) != 2:
raise NotImplementedError("Only 2D embedding inputs are supported.")
# The logic below is applied to properly handle repeated embedding indices.
# Specifically, sqr_grad_norms will contain the total counts of each embedding
# index (see how it is processed in the combine_pre_and_post_sqr_norms()
# function in clip_grads.py). An example is as follows:
#
# inputs =
# [[0 0 0 1 2 2],
# [0 2 2 2 1 1]]
#
# counts =
# [[3 1 2]
# [1 2 3]]
#
# sqr_grad_norms =
# [[3 3 3 1 2 2],
# [1 3 3 3 2 2]]
#
base_vars = layer_instance(*inputs)
indices = tf.cast(*inputs, tf.int32)
if isinstance(indices, tf.SparseTensor):
indices = tf.sparse.to_dense(indices)
counts = tf.math.bincount(indices, axis=-1)
sqr_grad_norms = tf.cast(
tf.gather(counts, indices, batch_dims=1), base_vars.dtype
)
return sqr_grad_norms, base_vars, None
# ==============================================================================
# Main factory methods
# ==============================================================================
def make_default_layer_registry():
registry = {}
registry[hash(tf.keras.layers.Dense)] = dense_layer_computation
registry[hash(tf.keras.layers.Embedding)] = embedding_layer_computation
return registry

View file

@@ -18,7 +18,7 @@ py_library(
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)
@@ -28,7 +28,7 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
"//tensorflow_privacy/privacy/keras_models:dp_keras_model",
],
)

View file

@@ -96,13 +96,10 @@ def make_dp_model_class(cls):
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches.
use_xla: If `True`, compiles train_step to XLA.
layer_registry: A `dict` of layers that support "fast" gradient norm
computations. The key is the class of the layer and the value is a
function that returns a `tuple` `(output, sqr_grad_norms, vars)`,
where `output` is the pre-activation tensor, `sqr_grad_norms` is
related to the squared norms of a layer's pre-activation tensor, and
`vars` are relevant trainable weights (see
`layer_registry_factories.py` for examples).
layer_registry: A `LayerRegistry` instance containing functions that
help compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""

View file

@@ -15,7 +15,7 @@
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.keras_models import dp_keras_model
@@ -32,7 +32,10 @@ def get_layer_registries():
# Outputs a list of testable layer registries.
# The empty registry {} tests the behavior of the standard approach,
# while the other one tests the fast gradient clipping algorithm.
return [{}, layer_registry_factories.make_default_layer_registry()]
return [
layer_registry.LayerRegistry(),
layer_registry.make_default_layer_registry(),
]
class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):