forked from 626_privacy/tensorflow_privacy
First implementation of the fast gradient clipping algorithm.
PiperOrigin-RevId: 504668189
This commit is contained in:
parent
ee3d349a8d
commit
a3b14ae20a
4 changed files with 564 additions and 0 deletions
25
tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
Normal file
25
tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
Normal file
|
@ -0,0 +1,25 @@
|
|||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
py_library(
|
||||
name = "gradient_clipping_utils",
|
||||
srcs = ["gradient_clipping_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "layer_registry_factories",
|
||||
srcs = ["layer_registry_factories.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "clip_grads",
|
||||
srcs = ["clip_grads.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":gradient_clipping_utils",
|
||||
":layer_registry_factories",
|
||||
],
|
||||
)
|
185
tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
Normal file
185
tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
Normal file
|
@ -0,0 +1,185 @@
|
|||
# Copyright 2022, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Computes per-example loss clip weights.
|
||||
|
||||
For a given Keras model and batch of inputs, computes the per-example
|
||||
clip weights so that the gradient of the loss function, weighted by these
|
||||
weights, is equivalent to the gradient of the original loss function but
|
||||
with the per-example gradients clipped by some clip weight. Uses a variant
|
||||
of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
|
||||
`compute_gradient_norms()` function).
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
|
||||
|
||||
def combine_pre_and_post_sqr_norms(pre_sqr_norm, post_grad, layer_hash):
|
||||
"""Combines pre and post-activation tensors for a given variable.
|
||||
|
||||
The logic for combining norms depends on the variable's underlying layer.
|
||||
|
||||
Args:
|
||||
pre_sqr_norm: A `tf.Tensor` whose first dimension is the batch dimension.
|
||||
Contains squared norms that are related to the pre-activation Tensor.
|
||||
post_grad: A `tf.Tensor` whose first dimension is the batch dimension.
|
||||
Contains gradients that are related to the post-activation Tensor.
|
||||
layer_hash: A `float` that is the hash of the variable's underlying layer
|
||||
class.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||
per-example loss function with respect to the given variable.
|
||||
"""
|
||||
post_sqr_grads = tf.square(post_grad)
|
||||
if layer_hash == hash(tf.keras.layers.Embedding):
|
||||
scaled_grads = tf.expand_dims(pre_sqr_norm, axis=-1) * post_sqr_grads
|
||||
reduction_axes = tf.range(1, tf.rank(scaled_grads))
|
||||
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
|
||||
else:
|
||||
reduction_axes = tf.range(1, tf.rank(post_sqr_grads))
|
||||
post_sqr_norm = tf.reduce_sum(post_sqr_grads, axis=reduction_axes)
|
||||
return pre_sqr_norm * post_sqr_norm
|
||||
|
||||
|
||||
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||
"""Computes the per-example loss gradient norms for given data.
|
||||
|
||||
Applies the approach given in https://arxiv.org/pdf/2009.03106.pdf, except
|
||||
the batch matrix multiplication operation in Algorithm 2 is replaced with
|
||||
the computation of two norm computations.
|
||||
|
||||
Args:
|
||||
input_model: The `tf.keras.Model` from which to obtain the layers from. The
|
||||
loss of the model *must* be a scalar loss.
|
||||
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
layer_registry: A `dict` of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
|
||||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||
trainable weights (see `layer_registry_factories.py` for examples).
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||
per-example loss function.
|
||||
"""
|
||||
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
||||
# First loop computes the norms of the layer inputs, caches these inputs,
|
||||
# and computes the summed loss.
|
||||
with tape:
|
||||
model_outputs, pre_norm_list, var_list, layer_hash_list = (
|
||||
gradient_clipping_utils.forward_norm_pass(
|
||||
input_model, x_batch, tape, layer_registry
|
||||
)
|
||||
)
|
||||
# Ignore the original loss function's reduction to get per-example loss.
|
||||
loss_config = input_model.loss.get_config()
|
||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||
losses = per_example_loss_fn(y_batch, model_outputs)
|
||||
if tf.rank(tf.squeeze(losses)) > 1:
|
||||
raise NotImplementedError('Vector losses are not supported.')
|
||||
summed_loss = tf.reduce_sum(losses)
|
||||
# Second loop computes the norm of the gradient of the loss with respect to
|
||||
# the pre-activation tensors, and multiplies these norms with the results of
|
||||
# the first loop.
|
||||
full_norm_list = []
|
||||
grads = tape.gradient(summed_loss, var_list)
|
||||
for i in range(len(var_list)):
|
||||
full_norm = combine_pre_and_post_sqr_norms(
|
||||
pre_norm_list[i], grads[i], layer_hash_list[i]
|
||||
)
|
||||
full_norm_list.append(full_norm)
|
||||
del tape
|
||||
# Post-processing for compatibility with non-eager mode (very annoying).
|
||||
full_norm_tsr = tf.stack(full_norm_list, axis=1)
|
||||
return tf.sqrt(tf.reduce_sum(full_norm_tsr, axis=1))
|
||||
|
||||
|
||||
def compute_clip_weights(l2_norm_clip, gradient_norms):
|
||||
"""Computes the per-example loss/clip weights for clipping.
|
||||
|
||||
When the sum of the per-example losses is replaced a weighted sum, where
|
||||
the weights are generated by this method, then the gradients of each
|
||||
term in the weighted sum are clipped by the given clip value.
|
||||
|
||||
Args:
|
||||
l2_norm_clip: A `float` indicating the norm to which per-example gradients
|
||||
will be clipped. That is, all gradients of the per-example loss functions
|
||||
will have norm at most `l2_norm_clip`.
|
||||
gradient_norms: A 1D `tf.Tensor` whose i-th entry is the norm of the
|
||||
gradient of the loss function for the i-th input.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` representing whose i-th entry `C[i]` is either `1.0` if the
|
||||
norm of the gradient of i-th per-example loss `G[i]` is less than
|
||||
`l2_norm_clip` or a number less than `1.0` so that
|
||||
`|G[i]| * C[i] == l2_norm_clip` otherwise.
|
||||
"""
|
||||
if l2_norm_clip is None:
|
||||
return None
|
||||
return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)
|
||||
|
||||
|
||||
def compute_pred_and_clipped_gradients(
|
||||
input_model, x_batch, y_batch, l2_norm_clip, layer_registry
|
||||
):
|
||||
"""Computes the per-example predictions and per-example clipped loss gradient.
|
||||
|
||||
Given a batch of observations `(x_batch, y_batch)`, the main steps of this
|
||||
function are: (i) compute the l2-norm of the gradients of the trainable
|
||||
variables of `input_model` for each example in the batch; (ii) use the norms
|
||||
computed in (i) to obtain "clip_weights" that are used to generate a weighted
|
||||
loss function whose gradient for each example has l2-norm at most
|
||||
`l2_norm_clip`; (iii) output the clipped gradients in (ii) and the
|
||||
`tf.Tensor` generated by `input_model` when it is given `x_batch` as its
|
||||
input.
|
||||
|
||||
Args:
|
||||
input_model: The `tf.keras.Model` from which to obtain the layers from.
|
||||
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
l2_norm_clip: A `float` indicating the norm to which per-example gradients
|
||||
will be clipped. That is, all gradients of the per-example loss functions
|
||||
will have norm at most `l2_norm_clip`.
|
||||
layer_registry: A `dict` of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
|
||||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||
trainable weights (see `layer_registry_factories.py` for examples).
|
||||
|
||||
Returns:
|
||||
A `tuple` `(y_pred, grad)`. The first element is the prediction generated by
|
||||
the model on the input `x_batch`. The second element is the clipped
|
||||
gradient of the loss function.
|
||||
"""
|
||||
gradient_norms = compute_gradient_norms(
|
||||
input_model, x_batch, y_batch, layer_registry
|
||||
)
|
||||
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
||||
with tf.GradientTape() as tape:
|
||||
y_pred = input_model(x_batch, training=True)
|
||||
loss_value = input_model.compute_loss(
|
||||
x_batch, y_batch, y_pred, loss_weights
|
||||
)
|
||||
return y_pred, tape.gradient(loss_value, input_model.trainable_variables)
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright 2022, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utility functions that help in the computation of per-example gradient norms."""
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def has_internal_compute_graph(input_object):
|
||||
"""Checks if input is a TF model and has a TF internal compute graph."""
|
||||
return (
|
||||
isinstance(input_object, tf.keras.Model)
|
||||
and hasattr(input_object, '_flatten_to_reference_inputs')
|
||||
and hasattr(input_object, '_tensor_usage_count')
|
||||
and hasattr(input_object, '_conform_to_reference_input')
|
||||
and hasattr(input_object, '_nodes_by_depth')
|
||||
)
|
||||
|
||||
|
||||
def forward_norm_pass(input_model, x_batch, tape, layer_registry):
|
||||
"""Does a forward pass of a model and returns some useful intermediates.
|
||||
|
||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||
_run_internal_graph() method in the functional.Functional class. Hence,
|
||||
forward_norm_pass should only be invoked if the generated model
|
||||
instance is an instance of the functional.Functional class.
|
||||
|
||||
Args:
|
||||
input_model: A Keras functional model to compute the quantities for.
|
||||
x_batch: A collection of Tensors to be fed into the input layer of the
|
||||
model.
|
||||
tape: A tf.GradientTape() instance used to record certain operations.
|
||||
Assumes that this function is being called inside of this tape.
|
||||
layer_registry: A dictionary of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a triple (norm_list, var_list, transform). For more
|
||||
details, see `layer_registry_factories.py`.
|
||||
|
||||
Returns:
|
||||
Four objects (outputs, norm_list, var_list, layer_hash_list). The first are
|
||||
the outputs that are generated as a result of a forward pass. The second is
|
||||
either `None` or a collection of squared norms for each example that helps
|
||||
in computing the gradient norm of an example (if such a "nice" computation
|
||||
exists). The third is an ordered list of tf.Tensor() objects that are
|
||||
intended to be evaluated with respect to the summed loss of the model. The
|
||||
fourth is a list whose i-th element is the hash of the layer class of the
|
||||
i-th element of var_list.
|
||||
"""
|
||||
# TODO: Avoid or remove the references to protected methods of `input_model`. # pylint: disable=g-bad-todo
|
||||
# Prepare the inputs and BFS variables.
|
||||
flattened_inputs = input_model._flatten_to_reference_inputs(x_batch) # pylint: disable=protected-access
|
||||
tensor_dict = {}
|
||||
tensor_usage_count = input_model._tensor_usage_count # pylint: disable=protected-access
|
||||
for x, y in zip(input_model.inputs, flattened_inputs):
|
||||
y = input_model._conform_to_reference_input(y, ref_input=x) # pylint: disable=protected-access
|
||||
x_id = str(id(x))
|
||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||
|
||||
# Main computations.
|
||||
nodes_by_depth = input_model._nodes_by_depth # pylint: disable=protected-access
|
||||
depth_keys = list(nodes_by_depth.keys())
|
||||
depth_keys.sort(reverse=True)
|
||||
var_list = []
|
||||
norm_list = []
|
||||
layer_hash_list = []
|
||||
node_outputs = None
|
||||
# Perform BFS feedforward computations.
|
||||
for depth in depth_keys:
|
||||
for node in nodes_by_depth[depth]:
|
||||
if node.is_input:
|
||||
continue # inputs already exist
|
||||
if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
|
||||
continue # node is not computable; try skipping
|
||||
args, kwargs = node.map_arguments(tensor_dict)
|
||||
if has_internal_compute_graph(node.layer):
|
||||
# If this node has an internal computational graph, we can recurse.
|
||||
node_outputs, node_norm_list, node_var_list, node_layer_hash_list = (
|
||||
forward_norm_pass(node.layer, args, tape, layer_registry)
|
||||
)
|
||||
var_list += node_var_list
|
||||
norm_list += node_norm_list
|
||||
layer_hash_list += node_layer_hash_list
|
||||
else:
|
||||
# Either pass through or record some metadata.
|
||||
if not node.layer.trainable_variables:
|
||||
node_outputs = node.layer(*args, **kwargs)
|
||||
else:
|
||||
lyr_hash = hash(node.layer.__class__)
|
||||
if lyr_hash not in layer_registry:
|
||||
raise NotImplementedError(
|
||||
'Layer %s is not in the registry of known layers that can'
|
||||
'be used for efficient gradient clipping.'
|
||||
% node.layer.__class__.__name__
|
||||
)
|
||||
lyr = layer_registry[lyr_hash]
|
||||
node_norms, node_vars, transform = lyr(node.layer, args)
|
||||
tape.watch(node_vars)
|
||||
node_outputs = transform(node_vars) if transform else node_vars
|
||||
var_list.append(node_vars)
|
||||
norm_list.append(node_norms)
|
||||
layer_hash_list.append(lyr_hash)
|
||||
# update the current dictionary of inputs for the next node.
|
||||
for x_id, y in zip(node.flat_output_ids, tf.nest.flatten(node_outputs)):
|
||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||
|
||||
return node_outputs, norm_list, var_list, layer_hash_list
|
||||
|
||||
|
||||
def get_trainable_hidden_layers(input_model):
|
||||
"""Obtains the trainable hidden layers of a Keras model.
|
||||
|
||||
Args:
|
||||
input_model: The Keras model to obtain the layers from.
|
||||
|
||||
Returns:
|
||||
A list of Keras layers where the tensorflow.keras.layers.Layer
|
||||
ancestor class MUST precede any existing tensorflow.keras.models.Model
|
||||
ancestor class.
|
||||
"""
|
||||
hidden_layers = []
|
||||
for l in input_model.layers:
|
||||
for c in l.__class__.__mro__:
|
||||
if c == tf.keras.models.Model:
|
||||
hidden_layers += get_trainable_hidden_layers(l)
|
||||
break
|
||||
elif c == tf.keras.layers.InputLayer:
|
||||
break
|
||||
elif c == tf.keras.layers.Layer:
|
||||
if l.trainable_variables:
|
||||
hidden_layers.append(l)
|
||||
break
|
||||
return hidden_layers
|
||||
|
||||
|
||||
def all_trainable_layers_are_registered(input_model, layer_registry):
|
||||
"""Check if an input model's trainable layers are all registered.
|
||||
|
||||
Args:
|
||||
input_model: The Keras model from which to obtain the layers from.
|
||||
layer_registry: A dictionary of layers that support "fast" gradient norm
|
||||
computations. The key is the class of the layer and the value is a
|
||||
function that returns a triple (output, sqr_grad_norms, vars), where
|
||||
output is the pre-activator tensor, sqr_grad_norms is the square of the
|
||||
norm of the layer's input, and vars is an ordered list of the trainable
|
||||
weights.
|
||||
|
||||
Returns:
|
||||
True if all the trainable layers in `input_model` are in `layer_registry`.
|
||||
False otherwise.
|
||||
"""
|
||||
hidden_layers = get_trainable_hidden_layers(input_model)
|
||||
for l in hidden_layers:
|
||||
if hash(l.__class__) not in layer_registry:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def add_aggregate_noise(
|
||||
input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier
|
||||
):
|
||||
"""Adds noise to a collection of clipped gradients.
|
||||
|
||||
The magnitude of the noise depends on the aggregation strategy of the
|
||||
input model's loss function.
|
||||
|
||||
Args:
|
||||
input_model: The Keras model to obtain the layers from.
|
||||
x_batch: A collection of Tensors to be fed into the input layer of the
|
||||
model.
|
||||
clipped_grads: A list of tensors representing the clipped gradients.
|
||||
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
|
||||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||
|
||||
Returns:
|
||||
A list of tensors containing the clipped gradients, but with the right
|
||||
amount of Gaussian noise added to them (depending on the reduction
|
||||
strategy of the loss function).
|
||||
"""
|
||||
scale = l2_norm_clip
|
||||
if input_model.loss.reduction in [
|
||||
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||
tf.keras.losses.Reduction.AUTO,
|
||||
]:
|
||||
if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
|
||||
logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
|
||||
if isinstance(x_batch, tf.Tensor):
|
||||
scale /= tf.cast(tf.shape(x_batch)[0], tf.float32)
|
||||
elif isinstance(x_batch, dict):
|
||||
batch_sizes = [
|
||||
tf.cast(tf.shape(v)[0], tf.float32) for v in x_batch.values()
|
||||
]
|
||||
scale /= tf.math.reduce_min(batch_sizes)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Unknown container/class %s for input' % x_batch.__class__.__name__
|
||||
)
|
||||
|
||||
def add_noise(g):
|
||||
return g + tf.random.normal(
|
||||
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
|
||||
)
|
||||
|
||||
return tf.nest.map_structure(add_noise, clipped_grads)
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2022, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Generates a default layer registry.
|
||||
|
||||
Defines "fast" gradient norm layer registry functions for use in the "fast"
|
||||
gradient clipping algorithm. Specifically, each registry function takes
|
||||
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
|
||||
outputs: (a) a `tf.Tensor` `G` of gradient norms, (b) a differentiable
|
||||
`tf.Tensor` `Z`, and (c) either `None` or a function that maps the object in
|
||||
(b) to the layer instance's output when using the inputs in (ii). If (c) is
|
||||
`None`, then (b) contains the layer outputs.
|
||||
|
||||
When a layer registry function is defined, it is generally assumed that the
|
||||
following relation holds for each pair `(g, z)` in `zip(G, Z)`:
|
||||
|
||||
`|dL/dw|^2 == |dL/dz|^2 * g^2`
|
||||
|
||||
where `L` is any per-example loss function and `w` are the trainable variables
|
||||
corresponding to `(g, z)`.
|
||||
|
||||
For example, this relation holds if the layer instance is tf.keras.layers.Dense,
|
||||
Z contains the pre-activation tensors, i.e., `z = X * w` for input `X`, and `g`
|
||||
is the norm of the input corresponding to the given per-example loss (see the
|
||||
formulae in https://arxiv.org/abs/1510.01799 for more details).
|
||||
|
||||
The registry functions are registered in a `dict` (registry) whose key is the
|
||||
hash of the layer class and whose value is the registry function.
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Supported Keras layers
|
||||
# ==============================================================================
|
||||
def dense_layer_computation(layer_instance, inputs):
|
||||
"""Registry function for `tf.keras.layers.Dense`.
|
||||
|
||||
The logic for this computation is based on the following paper:
|
||||
https://arxiv.org/abs/1510.01799
|
||||
|
||||
For the sake of efficiency, we fuse the variables and square grad norms
|
||||
for the kernel weights and bias vector together.
|
||||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Dense` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(sqr_grad_norms, base_vars, transform)`, where `norms` is a 1D
|
||||
`tf.Tensor` of the squared l2-norms of the input tensors, `base_vars` is the
|
||||
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
|
||||
`transform` is a function that maps `base_vars` to the layer outputs.
|
||||
"""
|
||||
orig_activation = layer_instance.activation
|
||||
layer_instance.activation = None
|
||||
base_vars = layer_instance(*inputs)
|
||||
sqr_inputs = tf.square(*inputs)
|
||||
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
|
||||
sqr_grad_norms = tf.reduce_sum(tf.square(*inputs), axis=inputs_reduction_axes)
|
||||
if layer_instance.use_bias:
|
||||
# Adding a bias term is equivalent to a layer with no bias term and which
|
||||
# adds an additional variable to the layer input that only takes a constant
|
||||
# value of 1.0. This is thus equivalent to adding 1.0 to the sum of the
|
||||
# squared values of the inputs.
|
||||
sqr_grad_norms += tf.cast(1.0, dtype=sqr_grad_norms.dtype)
|
||||
layer_instance.activation = orig_activation
|
||||
return sqr_grad_norms, base_vars, layer_instance.activation
|
||||
|
||||
|
||||
def embedding_layer_computation(layer_instance, inputs):
|
||||
"""Registry function for `tf.keras.layers.Embedding`.
|
||||
|
||||
The logic of this computation is based on the `tf.keras.layers.Dense`
|
||||
computation and the fact that an embedding layer is just a dense layer
|
||||
with no activation function and an output vector of the form X*W for input
|
||||
X, where the i-th row of W is the i-th embedding vector and the j-th row of
|
||||
X is a one-hot vector representing the input of example j.
|
||||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.Embedding` instance.
|
||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||
`layer_instance(inputs)` returns a valid output.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(sqr_grad_norms, base_vars, None)`, where `sqr_grad_norms` is
|
||||
a `tf.Tensor` that is related to the squared l2-norms of the input tensors
|
||||
and `base_vars` is the intermediate Tensor used in the chain-rule / "fast"
|
||||
clipping trick.
|
||||
"""
|
||||
if layer_instance.sparse:
|
||||
raise NotImplementedError("Sparse output vectors are not supported.")
|
||||
# The logic below is applied to properly handle repeated embedding indices.
|
||||
# Specifically, sqr_grad_norms will contain the total counts of each embedding
|
||||
# index (see how it is processed in the combine_pre_and_post_sqr_norms()
|
||||
# function in clip_grads.py). An example is as follows:
|
||||
#
|
||||
# inputs =
|
||||
# [[0 0 0 1 2 2],
|
||||
# [0 2 2 2 1 1]]
|
||||
#
|
||||
# counts =
|
||||
# [[3 1 2]
|
||||
# [1 2 3]]
|
||||
#
|
||||
# sqr_grad_norms =
|
||||
# [[3 3 3 1 2 2],
|
||||
# [1 3 3 3 2 2]]
|
||||
#
|
||||
base_vars = layer_instance(*inputs)
|
||||
indices = tf.cast(*inputs, tf.int32)
|
||||
if isinstance(indices, tf.SparseTensor):
|
||||
indices = tf.sparse.to_dense(indices)
|
||||
counts = tf.math.bincount(indices, axis=-1)
|
||||
sqr_grad_norms = tf.cast(
|
||||
tf.gather(counts, indices, batch_dims=1), base_vars.dtype
|
||||
)
|
||||
return sqr_grad_norms, base_vars, None
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main factory methods
|
||||
# ==============================================================================
|
||||
def make_default_layer_registry():
|
||||
registry = {}
|
||||
registry[hash(tf.keras.layers.Dense)] = dense_layer_computation
|
||||
registry[hash(tf.keras.layers.Embedding)] = embedding_layer_computation
|
||||
return registry
|
Loading…
Reference in a new issue