Generalize the internal API to allow for more general models + layers.
PiperOrigin-RevId: 509518753
This commit is contained in:
parent
6ee988885a
commit
410814ec39
9 changed files with 340 additions and 300 deletions
|
@ -9,8 +9,8 @@ py_library(
|
|||
)
|
||||
|
||||
py_library(
|
||||
name = "layer_registry_factories",
|
||||
srcs = ["layer_registry_factories.py"],
|
||||
name = "layer_registry",
|
||||
srcs = ["layer_registry.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ py_library(
|
|||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":gradient_clipping_utils",
|
||||
":layer_registry_factories",
|
||||
":layer_registry",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -31,6 +31,6 @@ py_test(
|
|||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":clip_grads",
|
||||
":layer_registry_factories",
|
||||
":layer_registry",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -25,40 +25,42 @@ import tensorflow as tf
|
|||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
|
||||
|
||||
def combine_pre_and_post_sqr_norms(pre_sqr_norm, post_grad, layer_hash):
|
||||
"""Combines pre and post-activation tensors for a given variable.
|
||||
|
||||
The logic for combining norms depends on the variable's underlying layer.
|
||||
|
||||
Args:
|
||||
pre_sqr_norm: A `tf.Tensor` whose first dimension is the batch dimension.
|
||||
Contains squared norms that are related to the pre-activation Tensor.
|
||||
post_grad: A `tf.Tensor` whose first dimension is the batch dimension.
|
||||
Contains gradients that are related to the post-activation Tensor.
|
||||
layer_hash: A `float` that is the hash of the variable's underlying layer
|
||||
class.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||
per-example loss function with respect to the given variable.
|
||||
"""
|
||||
post_sqr_grads = tf.square(post_grad)
|
||||
if layer_hash == hash(tf.keras.layers.Embedding):
|
||||
scaled_grads = tf.expand_dims(pre_sqr_norm, axis=-1) * post_sqr_grads
|
||||
reduction_axes = tf.range(1, tf.rank(scaled_grads))
|
||||
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
|
||||
def get_registry_generator_fn(tape, layer_registry):
|
||||
"""Creates the generator function for `compute_gradient_norms()`."""
|
||||
if layer_registry is None:
|
||||
# Needed for backwards compatibility.
|
||||
registry_generator_fn = None
|
||||
else:
|
||||
reduction_axes = tf.range(1, tf.rank(post_sqr_grads))
|
||||
post_sqr_norm = tf.reduce_sum(post_sqr_grads, axis=reduction_axes)
|
||||
return pre_sqr_norm * post_sqr_norm
|
||||
|
||||
def registry_generator_fn(layer_instance, args, kwargs):
|
||||
if layer_instance.trainable_variables:
|
||||
# Only trainable variables factor into the gradient.
|
||||
if not layer_registry.is_elem(layer_instance):
|
||||
raise NotImplementedError(
|
||||
'Layer %s is not in the registry of known layers that can '
|
||||
'be used for efficient gradient clipping.'
|
||||
% layer_instance.__class__.__name__
|
||||
)
|
||||
registry_fn = layer_registry.lookup(layer_instance)
|
||||
(layer_vars, transform, layer_sqr_norm_fn) = registry_fn(
|
||||
layer_instance, args
|
||||
)
|
||||
if tape is not None:
|
||||
tape.watch(layer_vars)
|
||||
layer_outputs = transform(layer_vars) if transform else layer_vars
|
||||
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
|
||||
else:
|
||||
# Non-trainable layer.
|
||||
return layer_instance(*args, **kwargs), None
|
||||
|
||||
return registry_generator_fn
|
||||
|
||||
|
||||
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||
"""Computes the per-example loss gradient norms for given data.
|
||||
|
||||
Applies the approach given in https://arxiv.org/pdf/2009.03106.pdf, except
|
||||
the batch matrix multiplication operation in Algorithm 2 is replaced with
|
||||
the computation of two norm computations.
|
||||
Applies a variant of the approach given in
|
||||
https://arxiv.org/pdf/2009.03106.pdf
|
||||
|
||||
Args:
|
||||
input_model: The `tf.keras.Model` from which to obtain the layers from. The
|
||||
|
@ -68,24 +70,22 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
|||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
layer_registry: A `dict` of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
|
||||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||
trainable weights (see `layer_registry_factories.py` for examples).
|
||||
layer_registry: A `LayerRegistry` instance containing functions that help
|
||||
compute gradient norms quickly. See
|
||||
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||
more details.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||
per-example loss function.
|
||||
"""
|
||||
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
||||
# First loop computes the norms of the layer inputs, caches these inputs,
|
||||
# and computes the summed loss.
|
||||
registry_generator_fn = get_registry_generator_fn(tape, layer_registry)
|
||||
# First loop computes the model outputs, summed loss, and generator outputs.
|
||||
with tape:
|
||||
model_outputs, pre_norm_list, var_list, layer_hash_list = (
|
||||
gradient_clipping_utils.forward_norm_pass(
|
||||
input_model, x_batch, tape, layer_registry
|
||||
model_outputs, generator_outputs_list = (
|
||||
gradient_clipping_utils.model_forward_pass(
|
||||
input_model, x_batch, generator_fn=registry_generator_fn
|
||||
)
|
||||
)
|
||||
# Ignore the original loss function's reduction to get per-example loss.
|
||||
|
@ -94,20 +94,19 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
|||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||
losses = per_example_loss_fn(y_batch, model_outputs)
|
||||
summed_loss = tf.reduce_sum(losses)
|
||||
# Second loop computes the norm of the gradient of the loss with respect to
|
||||
# the pre-activation tensors, and multiplies these norms with the results of
|
||||
# the first loop.
|
||||
full_norm_list = []
|
||||
grads = tape.gradient(summed_loss, var_list)
|
||||
for i in range(len(var_list)):
|
||||
full_norm = combine_pre_and_post_sqr_norms(
|
||||
pre_norm_list[i], grads[i], layer_hash_list[i]
|
||||
)
|
||||
full_norm_list.append(full_norm)
|
||||
# Unwrap the generator outputs so that the next loop avoids duplicating
|
||||
# backprop ops.
|
||||
filtered_outputs = [t for t in generator_outputs_list if t is not None]
|
||||
vars_list = [a for (a, b) in filtered_outputs]
|
||||
sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
|
||||
# Second loop evaluates the squared L2 norm functions and appends the results.
|
||||
grads_list = tape.gradient(summed_loss, vars_list)
|
||||
sqr_norm_list = []
|
||||
for grads, f in zip(grads_list, sqr_norm_fns_list):
|
||||
sqr_norm_list.append(f(grads))
|
||||
del tape
|
||||
# Post-processing for compatibility with non-eager mode (very annoying).
|
||||
full_norm_tsr = tf.stack(full_norm_list, axis=1)
|
||||
return tf.sqrt(tf.reduce_sum(full_norm_tsr, axis=1))
|
||||
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
|
||||
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
|
||||
|
||||
|
||||
def compute_clip_weights(l2_norm_clip, gradient_norms):
|
||||
|
|
|
@ -17,7 +17,7 @@ from absl.testing import parameterized
|
|||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -81,7 +81,7 @@ def get_computed_and_true_norms(
|
|||
)
|
||||
y_pred = model(x_input)
|
||||
y_batch = tf.ones_like(y_pred)
|
||||
registry = layer_registry_factories.make_default_layer_registry()
|
||||
registry = layer_registry.make_default_layer_registry()
|
||||
computed_norms = clip_grads.compute_gradient_norms(
|
||||
model, x_input, y_batch, layer_registry=registry
|
||||
)
|
||||
|
|
|
@ -28,8 +28,19 @@ def has_internal_compute_graph(input_object):
|
|||
)
|
||||
|
||||
|
||||
def forward_norm_pass(input_model, x_batch, tape, layer_registry):
|
||||
"""Does a forward pass of a model and returns some useful intermediates.
|
||||
def _get_internal_layers(input_layer):
|
||||
"""Returns a list of layers that are nested within a given layer."""
|
||||
internal_layers = []
|
||||
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
||||
for layer in input_layer.layers:
|
||||
internal_layers.extend(_get_internal_layers(layer))
|
||||
else:
|
||||
internal_layers.append(input_layer)
|
||||
return internal_layers
|
||||
|
||||
|
||||
def model_forward_pass(input_model, inputs, generator_fn=None):
|
||||
"""Does a forward pass of a model and returns useful intermediates.
|
||||
|
||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||
_run_internal_graph() method in the functional.Functional class. Hence,
|
||||
|
@ -37,44 +48,42 @@ def forward_norm_pass(input_model, x_batch, tape, layer_registry):
|
|||
instance is an instance of the functional.Functional class.
|
||||
|
||||
Args:
|
||||
input_model: A Keras functional model to compute the quantities for.
|
||||
x_batch: A collection of Tensors to be fed into the input layer of the
|
||||
model.
|
||||
tape: A tf.GradientTape() instance used to record certain operations.
|
||||
Assumes that this function is being called inside of this tape.
|
||||
layer_registry: A dictionary of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a triple (norm_list, var_list, transform). For more
|
||||
details, see `layer_registry_factories.py`.
|
||||
input_model: A `tf.keras.Model` to compute the quantities for.
|
||||
inputs: Arbitrary input to be fed into the input layer of the model. It is
|
||||
expected that `input_model(inputs)` returns a valid output.
|
||||
generator_fn: A function with signature `(tf.keras.layers.Layer, Any, Any)
|
||||
-> (tf.Tensor, Any)`, where we require `generator_fn(layer_instance, args,
|
||||
kwargs)[0] == layer_instance(*args, **kwargs)`. If `None`, then
|
||||
`layer_fn(layer_instance, args, kwargs)[1] == None`.
|
||||
|
||||
Returns:
|
||||
Four objects (outputs, norm_list, var_list, layer_hash_list). The first are
|
||||
the outputs that are generated as a result of a forward pass. The second is
|
||||
either `None` or a collection of squared norms for each example that helps
|
||||
in computing the gradient norm of an example (if such a "nice" computation
|
||||
exists). The third is an ordered list of tf.Tensor() objects that are
|
||||
intended to be evaluated with respect to the summed loss of the model. The
|
||||
fourth is a list whose i-th element is the hash of the layer class of the
|
||||
i-th element of var_list.
|
||||
A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
|
||||
`tf.Tensor` that is generated as a result of a forward pass.
|
||||
`generator_outputs_list` is a `list` whose i-th entry is the output of
|
||||
`generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th
|
||||
layer when the compute graph of `input_model` is traversed in BFS order.
|
||||
"""
|
||||
# TODO: Avoid or remove the references to protected methods of `input_model`. # pylint: disable=g-bad-todo
|
||||
|
||||
# Default generator.
|
||||
generator_outputs_list = []
|
||||
if generator_fn is None:
|
||||
|
||||
def generator_fn(layer_instance, args, kwargs):
|
||||
return layer_instance(*args, **kwargs), None
|
||||
|
||||
# Prepare the inputs and BFS variables.
|
||||
flattened_inputs = input_model._flatten_to_reference_inputs(x_batch) # pylint: disable=protected-access
|
||||
flattened_inputs = input_model._flatten_to_reference_inputs(inputs) # pylint: disable=protected-access
|
||||
tensor_dict = {}
|
||||
tensor_usage_count = input_model._tensor_usage_count # pylint: disable=protected-access
|
||||
for x, y in zip(input_model.inputs, flattened_inputs):
|
||||
y = input_model._conform_to_reference_input(y, ref_input=x) # pylint: disable=protected-access
|
||||
x_id = str(id(x))
|
||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||
|
||||
# Main computations.
|
||||
nodes_by_depth = input_model._nodes_by_depth # pylint: disable=protected-access
|
||||
depth_keys = list(nodes_by_depth.keys())
|
||||
depth_keys.sort(reverse=True)
|
||||
var_list = []
|
||||
norm_list = []
|
||||
layer_hash_list = []
|
||||
node_outputs = None
|
||||
|
||||
# Perform BFS feedforward computations.
|
||||
for depth in depth_keys:
|
||||
for node in nodes_by_depth[depth]:
|
||||
|
@ -85,62 +94,28 @@ def forward_norm_pass(input_model, x_batch, tape, layer_registry):
|
|||
args, kwargs = node.map_arguments(tensor_dict)
|
||||
if has_internal_compute_graph(node.layer):
|
||||
# If this node has an internal computational graph, we can recurse.
|
||||
node_outputs, node_norm_list, node_var_list, node_layer_hash_list = (
|
||||
forward_norm_pass(node.layer, args, tape, layer_registry)
|
||||
node_layer_outputs, node_generator_outputs = model_forward_pass(
|
||||
node.layer, args, generator_fn
|
||||
)
|
||||
var_list += node_var_list
|
||||
norm_list += node_norm_list
|
||||
layer_hash_list += node_layer_hash_list
|
||||
generator_outputs_list.extend(node_generator_outputs)
|
||||
else:
|
||||
# Either pass through or record some metadata.
|
||||
if not node.layer.trainable_variables:
|
||||
node_outputs = node.layer(*args, **kwargs)
|
||||
else:
|
||||
lyr_hash = hash(node.layer.__class__)
|
||||
if lyr_hash not in layer_registry:
|
||||
raise NotImplementedError(
|
||||
'Layer %s is not in the registry of known layers that can'
|
||||
'be used for efficient gradient clipping.'
|
||||
% node.layer.__class__.__name__
|
||||
)
|
||||
lyr = layer_registry[lyr_hash]
|
||||
node_norms, node_vars, transform = lyr(node.layer, args)
|
||||
tape.watch(node_vars)
|
||||
node_outputs = transform(node_vars) if transform else node_vars
|
||||
var_list.append(node_vars)
|
||||
norm_list.append(node_norms)
|
||||
layer_hash_list.append(lyr_hash)
|
||||
# update the current dictionary of inputs for the next node.
|
||||
for x_id, y in zip(node.flat_output_ids, tf.nest.flatten(node_outputs)):
|
||||
# Otherwise, we parse the node directly.
|
||||
node_layers = _get_internal_layers(node.layer)
|
||||
for layer in node_layers:
|
||||
node_layer_outputs, layer_generator_outputs = generator_fn(
|
||||
layer, args, kwargs
|
||||
)
|
||||
generator_outputs_list.append(layer_generator_outputs)
|
||||
args = (node_layer_outputs,)
|
||||
kwargs = {}
|
||||
|
||||
# Update the current dictionary of inputs for the next node.
|
||||
for x_id, y in zip(
|
||||
node.flat_output_ids, tf.nest.flatten(node_layer_outputs)
|
||||
):
|
||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||
|
||||
return node_outputs, norm_list, var_list, layer_hash_list
|
||||
|
||||
|
||||
def get_trainable_hidden_layers(input_model):
|
||||
"""Obtains the trainable hidden layers of a Keras model.
|
||||
|
||||
Args:
|
||||
input_model: The Keras model to obtain the layers from.
|
||||
|
||||
Returns:
|
||||
A list of Keras layers where the tensorflow.keras.layers.Layer
|
||||
ancestor class MUST precede any existing tensorflow.keras.models.Model
|
||||
ancestor class.
|
||||
"""
|
||||
hidden_layers = []
|
||||
for l in input_model.layers:
|
||||
for c in l.__class__.__mro__:
|
||||
if c == tf.keras.models.Model:
|
||||
hidden_layers += get_trainable_hidden_layers(l)
|
||||
break
|
||||
elif c == tf.keras.layers.InputLayer:
|
||||
break
|
||||
elif c == tf.keras.layers.Layer:
|
||||
if l.trainable_variables:
|
||||
hidden_layers.append(l)
|
||||
break
|
||||
return hidden_layers
|
||||
return node_layer_outputs, generator_outputs_list
|
||||
|
||||
|
||||
def all_trainable_layers_are_registered(input_model, layer_registry):
|
||||
|
@ -148,21 +123,19 @@ def all_trainable_layers_are_registered(input_model, layer_registry):
|
|||
|
||||
Args:
|
||||
input_model: The Keras model from which to obtain the layers from.
|
||||
layer_registry: A dictionary of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a triple (output, sqr_grad_norms, vars), where
|
||||
output is the pre-activator tensor, sqr_grad_norms is the square of the
|
||||
norm of the layer's input, and vars is an ordered list of the trainable
|
||||
weights.
|
||||
layer_registry: A `LayerRegistry` instance containing functions that help
|
||||
compute gradient norms quickly. See
|
||||
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||
more details.
|
||||
|
||||
Returns:
|
||||
True if all the trainable layers in `input_model` are in `layer_registry`.
|
||||
False otherwise.
|
||||
"""
|
||||
hidden_layers = get_trainable_hidden_layers(input_model)
|
||||
for l in hidden_layers:
|
||||
if hash(l.__class__) not in layer_registry:
|
||||
return False
|
||||
for layer in input_model.layers:
|
||||
for sublayer in _get_internal_layers(layer):
|
||||
if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
@ -212,3 +185,19 @@ def add_aggregate_noise(
|
|||
)
|
||||
|
||||
return tf.nest.map_structure(add_noise, clipped_grads)
|
||||
|
||||
|
||||
def generate_model_outputs_using_core_keras_layers(input_model):
|
||||
"""Returns the model outputs generated by only core Keras layers."""
|
||||
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
||||
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
||||
|
||||
def generator_fn(layer_instance, args, kwargs):
|
||||
if hash(layer_instance.__class__) in cust_hash_set:
|
||||
# Using `.call()` does not register the layer in the compute graph of
|
||||
# a forward pass.
|
||||
return layer_instance.call(*args, **kwargs), None
|
||||
else:
|
||||
return layer_instance(*args, **kwargs), None
|
||||
|
||||
return model_forward_pass(input_model, input_model.inputs, generator_fn)[0]
|
||||
|
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright 2022, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Defines the layer registry class and useful factory functions.
|
||||
|
||||
Defines "fast" gradient norm layer registry functions for use in the "fast"
|
||||
gradient clipping algorithm. Specifically, each registry function takes
|
||||
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
|
||||
outputs: (a) a differentiable `tf.Tensor` `Z`, (b) either `None` or a function
|
||||
that maps the object in (a) to the layer instance's output when using the
|
||||
inputs in (ii), and (c) a function `F` that generates the per-example
|
||||
squared gradient norms when it is fed an object representing the gradient of
|
||||
the summed loss with respect to `Z` in (a). If (b) is `None`, then (a) is
|
||||
expected to contain the layer outputs.
|
||||
|
||||
When a layer registry function is defined, it is generally assumed that the
|
||||
following relation holds:
|
||||
|
||||
`|dL/dW|^2 == F(grad_Z)`
|
||||
|
||||
where `gradient_Z` is the gradient of the summed loss with respect to `Z`.
|
||||
|
||||
For example, if the layer instance is tf.keras.layers.Dense, Z contains the
|
||||
pre-activation tensors, i.e., `z = X * w` for input `X`, and `g` is a tensor
|
||||
whose i-th entry is the L2 norm of the i-th input vector, then
|
||||
|
||||
`F(grad_Z) = g^2 * l2_row_norm(grad_Z)^2`,
|
||||
|
||||
where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
|
||||
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main class
|
||||
# ==============================================================================
|
||||
class LayerRegistry:
|
||||
"""Custom container for layer registry functions."""
|
||||
|
||||
def __init__(self):
|
||||
"""Basic initialization of various internal dictionaries."""
|
||||
self._layer_class_dict = {}
|
||||
self._registry = {}
|
||||
|
||||
def is_elem(self, layer_instance):
|
||||
"""Checks if a layer instance's class is in the registry."""
|
||||
return hash(layer_instance.__class__) in self._registry
|
||||
|
||||
def lookup(self, layer_instance):
|
||||
"""Returns the layer registry function for a given layer instance."""
|
||||
return self._registry[hash(layer_instance.__class__)]
|
||||
|
||||
def insert(self, layer_class, layer_registry_function):
|
||||
"""Inserts a layer registry function into the internal dictionaries."""
|
||||
layer_key = hash(layer_class)
|
||||
self._layer_class_dict[layer_key] = layer_class
|
||||
self._registry[layer_key] = layer_registry_function
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Supported Keras layers
|
||||
# ==============================================================================
|
||||
def dense_layer_computation(layer_instance, inputs):
|
||||
"""Registry function for `tf.keras.layers.Dense`.
|
||||
|
||||
The logic for this computation is based on the following paper:
|
||||
https://arxiv.org/abs/1510.01799
|
||||
|
||||
For the sake of efficiency, we fuse the variables and square grad norms
|
||||
for the kernel weights and bias vector together.
|
||||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Dense` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(base_vars, transform, sqr_norm_fn)`. `base_vars` is the
|
||||
intermediate Tensor used in the chain-rule / "fast" clipping trick,
|
||||
`transform` is a function that maps `base_vars` to the layer outputs, and
|
||||
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
||||
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
||||
where `tape` is a `tf.GradientTape` instance that records the dense
|
||||
layer computation and `summed_loss` is the sum of the per-example losses
|
||||
of the underlying model. This function then returns the per-example squared
|
||||
L2 gradient norms of the trainable variables in `layer_instance`. These
|
||||
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
|
||||
"""
|
||||
orig_activation = layer_instance.activation
|
||||
layer_instance.activation = None
|
||||
base_vars = layer_instance(*inputs)
|
||||
layer_instance.activation = orig_activation
|
||||
def sqr_norm_fn(base_vars_grads):
|
||||
sqr_inputs = tf.square(*inputs)
|
||||
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
|
||||
input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes)
|
||||
if layer_instance.use_bias:
|
||||
# Adding a bias term is equivalent to a layer with no bias term and which
|
||||
# adds an additional variable to the layer input that only takes a
|
||||
# constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
|
||||
# of the squared values of the inputs.
|
||||
input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype)
|
||||
reduction_axes = tf.range(1, tf.rank(base_vars_grads))
|
||||
base_vars_sqr_norms = tf.reduce_sum(
|
||||
tf.square(base_vars_grads), axis=reduction_axes
|
||||
)
|
||||
return input_sqr_norms * base_vars_sqr_norms
|
||||
|
||||
return base_vars, layer_instance.activation, sqr_norm_fn
|
||||
|
||||
|
||||
def embedding_layer_computation(layer_instance, inputs):
|
||||
"""Registry function for `tf.keras.layers.Embedding`.
|
||||
|
||||
The logic of this computation is based on the `tf.keras.layers.Dense`
|
||||
computation and the fact that an embedding layer is just a dense layer
|
||||
with no activation function and an output vector of the form X*W for input
|
||||
X, where the i-th row of W is the i-th embedding vector and the j-th row of
|
||||
X is a one-hot vector representing the input of example j.
|
||||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Embedding` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(base_vars, None, sqr_norm_fn)`, `base_vars` is the
|
||||
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
|
||||
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
||||
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
||||
where `tape` is a `tf.GradientTape` instance that records the dense
|
||||
layer computation and `summed_loss` is the sum of the per-example losses
|
||||
of the underlying model. This function then returns the per-example squared
|
||||
L2 gradient norms of the trainable variables in `layer_instance`. These
|
||||
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
|
||||
"""
|
||||
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
||||
if layer_instance.sparse:
|
||||
raise NotImplementedError("Sparse output vectors are not supported.")
|
||||
if len(inputs[0].shape) != 2:
|
||||
raise NotImplementedError("Only 2D embedding inputs are supported.")
|
||||
# The logic below is applied to properly handle repeated embedding indices.
|
||||
# Specifically, sqr_grad_norms will contain the total counts of each embedding
|
||||
# index (see how it is processed in the combine_pre_and_post_sqr_norms()
|
||||
# function in clip_grads.py). An example is as follows:
|
||||
#
|
||||
# inputs =
|
||||
# [[0 0 0 1 2 2],
|
||||
# [0 2 2 2 1 1]]
|
||||
#
|
||||
# counts =
|
||||
# [[3 1 2]
|
||||
# [1 2 3]]
|
||||
#
|
||||
# input_counts =
|
||||
# [[3 3 3 1 2 2],
|
||||
# [1 3 3 3 2 2]]
|
||||
#
|
||||
base_vars = layer_instance(*inputs)
|
||||
def sqr_norm_fn(base_vars_grads):
|
||||
indices = tf.cast(*inputs, tf.int32)
|
||||
if isinstance(indices, tf.SparseTensor):
|
||||
indices = tf.sparse.to_dense(indices)
|
||||
counts = tf.math.bincount(indices, axis=-1)
|
||||
input_counts = tf.expand_dims(
|
||||
tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
|
||||
axis=-1,
|
||||
)
|
||||
scaled_grads = input_counts * tf.square(base_vars_grads)
|
||||
reduction_axes = tf.range(1, tf.rank(scaled_grads))
|
||||
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
|
||||
|
||||
return base_vars, None, sqr_norm_fn
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main factory methods
|
||||
# ==============================================================================
|
||||
def make_default_layer_registry():
|
||||
registry = LayerRegistry()
|
||||
registry.insert(tf.keras.layers.Dense, dense_layer_computation)
|
||||
registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
|
||||
return registry
|
|
@ -1,143 +0,0 @@
|
|||
# Copyright 2022, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Generates a default layer registry.
|
||||
|
||||
Defines "fast" gradient norm layer registry functions for use in the "fast"
|
||||
gradient clipping algorithm. Specifically, each registry function takes
|
||||
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
|
||||
outputs: (a) a `tf.Tensor` `G` of gradient norms, (b) a differentiable
|
||||
`tf.Tensor` `Z`, and (c) either `None` or a function that maps the object in
|
||||
(b) to the layer instance's output when using the inputs in (ii). If (c) is
|
||||
`None`, then (b) contains the layer outputs.
|
||||
|
||||
When a layer registry function is defined, it is generally assumed that the
|
||||
following relation holds for each pair `(g, z)` in `zip(G, Z)`:
|
||||
|
||||
`|dL/dw|^2 == |dL/dz|^2 * g^2`
|
||||
|
||||
where `L` is any per-example loss function and `w` are the trainable variables
|
||||
corresponding to `(g, z)`.
|
||||
|
||||
For example, this relation holds if the layer instance is tf.keras.layers.Dense,
|
||||
Z contains the pre-activation tensors, i.e., `z = X * w` for input `X`, and `g`
|
||||
is the norm of the input corresponding to the given per-example loss (see the
|
||||
formulae in https://arxiv.org/abs/1510.01799 for more details).
|
||||
|
||||
The registry functions are registered in a `dict` (registry) whose key is the
|
||||
hash of the layer class and whose value is the registry function.
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Supported Keras layers
|
||||
# ==============================================================================
|
||||
def dense_layer_computation(layer_instance, inputs):
|
||||
"""Registry function for `tf.keras.layers.Dense`.
|
||||
|
||||
The logic for this computation is based on the following paper:
|
||||
https://arxiv.org/abs/1510.01799
|
||||
|
||||
For the sake of efficiency, we fuse the variables and square grad norms
|
||||
for the kernel weights and bias vector together.
|
||||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Dense` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(sqr_grad_norms, base_vars, transform)`, where `norms` is a 1D
|
||||
`tf.Tensor` of the squared l2-norms of the input tensors, `base_vars` is the
|
||||
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
|
||||
`transform` is a function that maps `base_vars` to the layer outputs.
|
||||
"""
|
||||
orig_activation = layer_instance.activation
|
||||
layer_instance.activation = None
|
||||
base_vars = layer_instance(*inputs)
|
||||
sqr_inputs = tf.square(*inputs)
|
||||
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
|
||||
sqr_grad_norms = tf.reduce_sum(tf.square(*inputs), axis=inputs_reduction_axes)
|
||||
if layer_instance.use_bias:
|
||||
# Adding a bias term is equivalent to a layer with no bias term and which
|
||||
# adds an additional variable to the layer input that only takes a constant
|
||||
# value of 1.0. This is thus equivalent to adding 1.0 to the sum of the
|
||||
# squared values of the inputs.
|
||||
sqr_grad_norms += tf.cast(1.0, dtype=sqr_grad_norms.dtype)
|
||||
layer_instance.activation = orig_activation
|
||||
return sqr_grad_norms, base_vars, layer_instance.activation
|
||||
|
||||
|
||||
def embedding_layer_computation(layer_instance, inputs):
|
||||
"""Registry function for `tf.keras.layers.Embedding`.
|
||||
|
||||
The logic of this computation is based on the `tf.keras.layers.Dense`
|
||||
computation and the fact that an embedding layer is just a dense layer
|
||||
with no activation function and an output vector of the form X*W for input
|
||||
X, where the i-th row of W is the i-th embedding vector and the j-th row of
|
||||
X is a one-hot vector representing the input of example j.
|
||||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Embedding` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(sqr_grad_norms, base_vars, None)`, where `sqr_grad_norms` is
|
||||
a `tf.Tensor` that is related to the squared l2-norms of the input tensors
|
||||
and `base_vars` is the intermediate Tensor used in the chain-rule / "fast"
|
||||
clipping trick.
|
||||
"""
|
||||
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
||||
if layer_instance.sparse:
|
||||
raise NotImplementedError("Sparse output vectors are not supported.")
|
||||
if tf.rank(*inputs) != 2:
|
||||
raise NotImplementedError("Only 2D embedding inputs are supported.")
|
||||
# The logic below is applied to properly handle repeated embedding indices.
|
||||
# Specifically, sqr_grad_norms will contain the total counts of each embedding
|
||||
# index (see how it is processed in the combine_pre_and_post_sqr_norms()
|
||||
# function in clip_grads.py). An example is as follows:
|
||||
#
|
||||
# inputs =
|
||||
# [[0 0 0 1 2 2],
|
||||
# [0 2 2 2 1 1]]
|
||||
#
|
||||
# counts =
|
||||
# [[3 1 2]
|
||||
# [1 2 3]]
|
||||
#
|
||||
# sqr_grad_norms =
|
||||
# [[3 3 3 1 2 2],
|
||||
# [1 3 3 3 2 2]]
|
||||
#
|
||||
base_vars = layer_instance(*inputs)
|
||||
indices = tf.cast(*inputs, tf.int32)
|
||||
if isinstance(indices, tf.SparseTensor):
|
||||
indices = tf.sparse.to_dense(indices)
|
||||
counts = tf.math.bincount(indices, axis=-1)
|
||||
sqr_grad_norms = tf.cast(
|
||||
tf.gather(counts, indices, batch_dims=1), base_vars.dtype
|
||||
)
|
||||
return sqr_grad_norms, base_vars, None
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main factory methods
|
||||
# ==============================================================================
|
||||
def make_default_layer_registry():
|
||||
registry = {}
|
||||
registry[hash(tf.keras.layers.Dense)] = dense_layer_computation
|
||||
registry[hash(tf.keras.layers.Embedding)] = embedding_layer_computation
|
||||
return registry
|
|
@ -18,7 +18,7 @@ py_library(
|
|||
deps = [
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -28,7 +28,7 @@ py_test(
|
|||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
|
||||
"//tensorflow_privacy/privacy/keras_models:dp_keras_model",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -96,13 +96,10 @@ def make_dp_model_class(cls):
|
|||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||
num_microbatches: Number of microbatches.
|
||||
use_xla: If `True`, compiles train_step to XLA.
|
||||
layer_registry: A `dict` of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a `tuple` `(output, sqr_grad_norms, vars)`,
|
||||
where `output` is the pre-activator tensor, `sqr_grad_norms` is
|
||||
related to the squared norms of a layer's pre-activation tensor, and
|
||||
`vars` are relevant trainable weights (see
|
||||
`layer_registry_factories.py` for examples).
|
||||
layer_registry: A `LayerRegistry` instance containing functions that
|
||||
help compute gradient norms quickly. See
|
||||
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||
more details.
|
||||
*args: These will be passed on to the base class `__init__` method.
|
||||
**kwargs: These will be passed on to the base class `__init__` method.
|
||||
"""
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
from tensorflow_privacy.privacy.keras_models import dp_keras_model
|
||||
|
||||
|
||||
|
@ -32,7 +32,10 @@ def get_layer_registries():
|
|||
# Outputs a list of testable layer registries.
|
||||
# The empty registry {} tests the behavior of the standard approach,
|
||||
# while the other one tests the fast gradient clipping algorithm.
|
||||
return [{}, layer_registry_factories.make_default_layer_registry()]
|
||||
return [
|
||||
layer_registry.LayerRegistry(),
|
||||
layer_registry.make_default_layer_registry(),
|
||||
]
|
||||
|
||||
|
||||
class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
|
Loading…
Reference in a new issue