Implement and test a registry function for tfm.nlp.layers.EinsumDense + small formatting fixes.

PiperOrigin-RevId: 576215816
William Kong 2023-10-24 11:54:24 -07:00 committed by A. Unique TensorFlower
parent 8b52ba246c
commit 39c8a8c1af
6 changed files with 442 additions and 5 deletions

@@ -13,6 +13,7 @@ py_library(
name = "einsum_utils",
srcs = ["einsum_utils.py"],
srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils"],
)
py_test(
@@ -24,6 +25,33 @@ py_test(
deps = [":einsum_utils"],
)
py_library(
name = "einsum_dense",
srcs = ["einsum_dense.py"],
srcs_version = "PY3",
deps = [
":einsum_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
],
)
py_test(
name = "einsum_dense_test",
size = "large",
srcs = ["einsum_dense_test.py"],
python_version = "PY3",
shard_count = 12,
srcs_version = "PY3",
deps = [
":dense",
":einsum_dense",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)
py_library(
name = "dense",
srcs = ["dense.py"],

@@ -0,0 +1,70 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.EinsumDense`."""
from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
def einsum_layer_computation(
layer_instance: tf.keras.layers.EinsumDense,
input_args: Sequence[Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.EinsumDense`.
For the technical details, see the documentation of
`einsum_utils.compute_fast_einsum_squared_gradient_norm()`.
Args:
layer_instance: A `tf.keras.layers.EinsumDense` instance.
input_args: See `dense_layer_computation()` in `dense.py`.
input_kwargs: See `dense_layer_computation()` in `dense.py`.
tape: See `dense_layer_computation()` in `dense.py`.
num_microbatches: See `dense_layer_computation()` in `dense.py`.
Returns:
See `dense_layer_computation()` in `dense.py`.
"""
if input_kwargs:
raise ValueError("EinsumDense layer calls should not receive kwargs.")
del input_kwargs
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
orig_activation = layer_instance.activation
# Some activation functions may not apply a transform to the elements of the
# output individually (which the fast clipping trick requires). To avoid this
# case, we watch only the variables generated by the linear transformation of
# the `EinsumDense` layer instance.
layer_instance.activation = None
base_vars = layer_instance(*input_args)
tape.watch(base_vars)
layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars
def sqr_norm_fn(grads):
return einsum_utils.compute_fast_einsum_squared_gradient_norm(
layer_instance.equation,
input_args[0],
grads,
layer_instance.bias_axes,
num_microbatches,
)
return base_vars, outputs, sqr_norm_fn
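For context (not part of this commit), below is a minimal sketch of how this registry function is wired into a `LayerRegistry`, mirroring `get_einsum_layer_registry()` in the test file further down; the variable name `registry` is illustrative, and only APIs that appear elsewhere in this diff are used:
```
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense

# Register the EinsumDense computation so the fast gradient clipping machinery
# knows how to compute per-example squared gradient norms for this layer type.
registry = layer_registry.LayerRegistry()
registry.insert(
    tfm.nlp.layers.EinsumDense,
    einsum_dense.einsum_layer_computation,
)
# `registry` can then be passed wherever a `registry=` argument is expected,
# e.g. `common_test_utils.get_computed_and_true_norms_from_model()` below.
```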

@@ -0,0 +1,171 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense
def get_einsum_layer_generators():
def pure_einsum_layer(equation, output_dims, bias_axes):
return tf.keras.layers.EinsumDense(
equation, output_dims, bias_axes=bias_axes
)
def sigmoid_einsum_layer(equation, output_dims, bias_axes):
return tf.keras.layers.EinsumDense(
equation, output_dims, bias_axes=bias_axes, activation='sigmoid'
)
return {
'pure_einsum': pure_einsum_layer,
'sigmoid_einsum': sigmoid_einsum_layer,
}
def get_einsum_parameter_tuples():
# (equation, input_dims, output_dims, bias_axes)
return [
# Case C1.
('ab,bc->ac', [2], [3], None),
('ab,bc->ac', [2], [3], 'c'),
('abc,cd->abd', [2, 3], [2, 4], None),
('abc,cd->abd', [2, 3], [2, 4], 'b'),
('abc,cd->abd', [2, 3], [2, 4], 'd'),
('abc,cd->abd', [2, 3], [2, 4], 'bd'),
('abc,cef->abef', [2, 3], [2, 4, 5], None),
('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'),
# Case C2.
('...b,bc->...c', [2, 3], [4], None),
('...b,bc->...c', [2, 3], [4], 'c'),
('...ab,bc->...ac', [2, 3], [2, 4], None),
('...ab,bc->...ac', [2, 4], [2, 4], 'c'),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'),
('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None),
('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'),
# Case C3.
('ab...,bc->ac...', [2, 3], [4, 3], None),
('ab...,bc->ac...', [2, 3], [4, 3], 'c'),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'),
('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None),
('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'),
]
def get_einsum_layer_registry():
einsum_registry = layer_registry.LayerRegistry()
einsum_registry.insert(
tfm.nlp.layers.EinsumDense,
einsum_dense.einsum_layer_computation,
)
return einsum_registry
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
self.using_tpu = False
@parameterized.product(
layer_name=list(get_einsum_layer_generators()),
param_tuple=get_einsum_parameter_tuples(),
num_microbatches=[None, 2],
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self,
layer_name,
param_tuple,
num_microbatches,
is_eager,
):
# Parse inputs to generate test data. Note that each batched input is a
# reshape of a `tf.range()` call.
equation, input_dims, output_dims, bias_axes = param_tuple
batch_size = 4
example_size = tf.reduce_prod(input_dims)
example_values = tf.range(batch_size * example_size, dtype=tf.float32)
x_batch = tf.reshape(example_values, [batch_size] + input_dims)
# Make the layer generator via currying.
einsum_generator = get_einsum_layer_generators()[layer_name]
def curried_generator(a, b):
del a, b
return einsum_generator(equation, output_dims, bias_axes)
# Load shared assets onto all devices.
with self.strategy.scope():
model = common_test_utils.get_model_from_generator(
model_generator=common_test_utils.make_one_layer_functional_model,
layer_generator=curried_generator,
input_dims=input_dims,
output_dims=output_dims,
is_eager=is_eager,
)
# Define the main testing ops. These may later be compiled into a graph op.
def test_op(x):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=None,
num_microbatches=num_microbatches,
x_batch=x,
registry=get_einsum_layer_registry(),
)
# TPUs can only run `tf.function`-decorated functions.
if self.using_tpu:
test_op = tf.function(test_op, autograph=False)
# TPUs use lower precision than CPUs, so we relax our criterion.
# E.g., one of the TPU runs generated the following results:
#
# computed_norm = 93.48296
# true_norm = 93.31176
# abs_diff = 0.17120361
# rel_diff = 0.00183475
#
# which is a reasonable level of error for computing gradient norms.
# Other trials also give an absolute (resp. relative) error of around
# 0.05 (resp. 0.0015).
rtol = 1e-2 if self.using_tpu else 1e-3
atol = 5e-1 if self.using_tpu else 1e-2
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
# TPUs return replica contexts, which must be unwrapped.
if self.using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
expected_size = num_microbatches or batch_size
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
if __name__ == '__main__':
tf.test.main()

@@ -0,0 +1,30 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense_test
class GradNormTpuTest(einsum_dense_test.GradNormTest):
def setUp(self):
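# Skip `GradNormTest.setUp`, which would install the default (non-TPU)
# strategy, while still running the base test-case setup; the TPU strategy
# is created below instead.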
super(einsum_dense_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
self.using_tpu = True
if __name__ == '__main__':
tf.test.main()

@@ -17,9 +17,11 @@ import enum
import itertools
import os
import re
from typing import Optional
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
EquationType = enum.Enum(
"EquationType",
@@ -139,9 +141,7 @@ def _reshape_einsum_inputs(
(num_batches, num_rows, num_columns)
```
When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1
and the number of output columns is the second dimension of the input. The
product of the non-trivial dimensions of the output should be equal to
the product of the dimensions of `input_tensor`.
and the number of output columns is the second dimension of the input.
Raises:
ValueError: If `equation` is not a valid einsum equation in the context of
@@ -345,3 +345,141 @@ def _get_einsum_bias_adjoint_reduction_axes(
else:
reduction_axes.append(idx)
return reduction_axes
def compute_fast_einsum_squared_gradient_norm(
equation: str,
input_tensor: tf.Tensor,
grad_tensor: tf.Tensor,
bias_axes: Optional[str],
num_microbatches: Optional[int] = None,
) -> tf.Tensor:
"""Computes the batch gradient norms of an Einsum gradient decompostion.
This logic generalizes the one for `tf.keras.layers.Dense` and assumes that
the `equation` parameter is one of the following forms:
C1. ab,bc->ac,
C2. ...ab,bc->...ac
C3. ab...,bc->ac...
where `a`, `b`, and `c` are non-empty substrings.
For reference, we describe part of the mathematical analysis below. It can be
safely skipped on a first reading of this docstring.
-----------------------------------------------------------------------------
BEGIN ANALYSIS
-----------------------------------------------------------------------------
For ease of exposition, all analysis is done for a single example, i.e., the
batch dimension is excluded from consideration.
Recall that the einsum dense computation, excluding activation functions, is
of the form
```
output = tf.einsum(equation, input, kernel) + bias,
```
where `bias` is broadcasted and summed with the output of the `tf.einsum()`
call, and `equation` takes one of the forms C1, C2, or C3.
Mathematically, the above computation is equivalent to:
```
output = tf.matmul(X, W) + Q(bias)
```
where `X` (resp. `W`) is a 2D tensor reshaped from `input` (resp. `kernel`)
and `Q` is a linear operator that transforms `bias` to comport with the
tensor output by the `tf.matmul()` call. When generalizing to a batch of
examples, `X` is a 3D tensor whose first dimension is the batch dimension.
Following the same trick as for `tf.keras.layers.Dense` layers, suppose that
we have:
```
loss = f(output)
G = tape.gradient(loss, output)
```
Then, using the chain rule and denoting by `A'` the adjoint of a matrix `A`,
one can show that the gradient of `loss` with respect to the layer variables
`(W, bias)` is given by the block matrix `K := [X'G; Q'G]`. Hence, the squared
norm of `K`, i.e., what is returned by `sqr_norm_fn`, is given by
```
sqr_norm = <XX', GG'> + ||Q'G||_F^2
```
where `||.||_F` is the Frobenius norm and `<.,.>` is the Euclidean inner
product for matrices.
-----------------------------------------------------------------------------
END ANALYSIS
-----------------------------------------------------------------------------
Args:
equation: A `string` representing the einsum equation.
input_tensor: A `tf.Tensor` representing the einsum input.
grad_tensor: A `tf.Tensor` that is the gradient of the scalar loss with
respect to the pre-activation tensor.
bias_axes: An optional `string` that specifies the einsum biases in
`equation`.
num_microbatches: An optional `int` that specifies the number of
microbatches used in a batch.
Returns:
A 1D `tf.Tensor` whose i-th entry is the squared gradient norm corresponding
to the i-th example in `input_tensor`.
"""
# Compute the matrices `X X'` and `G G'` for each example or microbatch.
# `x.shape = (batch_size, num_rows, num_columns)`
x = _reshape_einsum_inputs(input_tensor, equation)
g = _reshape_einsum_outputs(grad_tensor, equation)
# Adding microbatches is equivalent to splitting the first `(batch_size)`
# axis into `(num_microbatches, microbatch_size)` axes and merging the
# `microbatch_size` axis with the `num_rows` axis via a reshape.
if num_microbatches is not None:
# `x.shape = (num_microbatches, microbatch_size, num_rows, num_columns)`
x = common_manip_utils.maybe_add_microbatch_axis(x, num_microbatches)
g = common_manip_utils.maybe_add_microbatch_axis(g, num_microbatches)
sx = tf.shape(x)
sg = tf.shape(g)
# `x.shape = (num_microbatches, microbatch_size * num_rows, num_columns)`
x = tf.reshape(x, shape=[sx[0], sx[1] * sx[2], sx[3]])
g = tf.reshape(g, shape=[sg[0], sg[1] * sg[2], sg[3]])
# NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
# a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
if (
_is_batch_of_vectors(input_tensor)
and _is_batch_of_vectors(grad_tensor)
and num_microbatches is None
):
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])
g_matrix = tf.reshape(g, [tf.shape(g)[0], -1])
batch_xxt = tf.reduce_sum(tf.square(x_matrix), axis=1)
batch_ggt = tf.reduce_sum(tf.square(g_matrix), axis=1)
else:
batch_xxt = tf.matmul(x, x, transpose_b=True)
batch_ggt = tf.matmul(g, g, transpose_b=True)
# Compute the (micro)batch inner product; adjust for biases if necessary.
batch_xxt_ggt = tf.multiply(batch_xxt, batch_ggt)
reduction_axes = tf.range(1, tf.rank(batch_xxt_ggt))
sqr_norms = tf.reduce_sum(batch_xxt_ggt, axis=reduction_axes)
if bias_axes is not None:
# Applying the adjoint operator `Q'` to `G` amounts to a reduce sum over the
# axes of `G` that are not broadcasted from `bias`.
grad_rank = len(grad_tensor.shape)
adjoint_reduction_axes = _get_einsum_bias_adjoint_reduction_axes(
equation,
bias_axes,
grad_rank,
)
# Adding microbatches with non-trivial bias axes is equivalent to splitting
# the first `(batch_size)` axis into `(num_microbatches, microbatch_size)`
# axes, and adding the `microbatch_size` axis (=1) to the reduction axes
# needed to compute the bias broadcast adjoint operator.
if num_microbatches is not None:
grad_tensor = common_manip_utils.maybe_add_microbatch_axis(
grad_tensor, num_microbatches
)
adjoint_reduction_axes = [i + 1 for i in adjoint_reduction_axes]
adjoint_reduction_axes = [1] + adjoint_reduction_axes
qg = tf.reduce_sum(grad_tensor, axis=adjoint_reduction_axes)
qg_reduction_axes = tf.range(1, tf.rank(qg))
bias_sqr_norms = tf.reduce_sum(tf.square(qg), axis=qg_reduction_axes)
sqr_norms += bias_sqr_norms
return sqr_norms
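As a sanity check (not part of this commit), the `<XX', GG'>` term in the analysis above can be verified numerically for the C1 case `ab,bc->ac` with no bias and no microbatches. Only standard TensorFlow ops are used; the shapes below are arbitrary illustrative values:
```
import tensorflow as tf

batch_size, num_inputs, num_outputs = 4, 3, 5
x = tf.random.normal([batch_size, num_inputs])   # per-example inputs
g = tf.random.normal([batch_size, num_outputs])  # per-example output gradients

# Fast path: for rank-2 inputs each per-example `X` has a single row, so
# <XX', GG'> reduces to ||x_i||^2 * ||g_i||^2 (cf. the batch-of-vectors
# branch above).
fast_sqr_norms = tf.reduce_sum(tf.square(x), axis=1) * tf.reduce_sum(
    tf.square(g), axis=1
)

# Brute force: for `output = x @ W`, the per-example kernel gradient is the
# outer product x_i' g_i; its squared Frobenius norm should match.
per_example_grads = tf.einsum('bi,bj->bij', x, g)
brute_sqr_norms = tf.reduce_sum(tf.square(per_example_grads), axis=[1, 2])

tf.debugging.assert_near(fast_sqr_norms, brute_sqr_norms)
```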

@@ -13,7 +13,7 @@
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test
@@ -21,7 +21,7 @@ class GradNormTpuTest(layer_normalization_test.GradNormTest):
def setUp(self):
super(layer_normalization_test.GradNormTest, self).setUp()
self.strategy = ctu.create_tpu_strategy()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
self.using_tpu = True