Implement and test a registry function for tfm.nlp.layers.EinsumDense + small formatting fixes.

PiperOrigin-RevId: 576215816
William Kong 2023-10-24 11:54:24 -07:00 committed by A. Unique TensorFlower
parent 8b52ba246c
commit 39c8a8c1af
6 changed files with 442 additions and 5 deletions


@@ -13,6 +13,7 @@ py_library(
    name = "einsum_utils",
    srcs = ["einsum_utils.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils"],
)

py_test(
@@ -24,6 +25,33 @@ py_test(
    deps = [":einsum_utils"],
)
py_library(
name = "einsum_dense",
srcs = ["einsum_dense.py"],
srcs_version = "PY3",
deps = [
":einsum_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
],
)
py_test(
name = "einsum_dense_test",
size = "large",
srcs = ["einsum_dense_test.py"],
python_version = "PY3",
shard_count = 12,
srcs_version = "PY3",
deps = [
":dense",
":einsum_dense",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)
py_library(
    name = "dense",
    srcs = ["dense.py"],


@@ -0,0 +1,70 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.EinsumDense`."""
from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
def einsum_layer_computation(
layer_instance: tf.keras.layers.EinsumDense,
input_args: Sequence[Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.EinsumDense`.
  For the technical details, see the documentation of
  `einsum_utils.compute_fast_einsum_squared_gradient_norm()`.
Args:
layer_instance: A `tf.keras.layers.EinsumDense` instance.
input_args: See `dense_layer_computation()` in `dense.py`.
input_kwargs: See `dense_layer_computation()` in `dense.py`.
tape: See `dense_layer_computation()` in `dense.py`.
num_microbatches: See `dense_layer_computation()` in `dense.py`.
Returns:
See `dense_layer_computation()` in `dense.py`.
"""
if input_kwargs:
raise ValueError("EinsumDense layer calls should not receive kwargs.")
del input_kwargs
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
orig_activation = layer_instance.activation
  # Some activation functions may not apply a transform to the elements of the
  # output individually (which is needed for the fast clipping trick to work).
  # To avoid this case, we watch only the variables generated by the linear
  # transformation of the `EinsumDense` layer instance.
layer_instance.activation = None
base_vars = layer_instance(*input_args)
tape.watch(base_vars)
layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars
def sqr_norm_fn(grads):
return einsum_utils.compute_fast_einsum_squared_gradient_norm(
layer_instance.equation,
input_args[0],
grads,
layer_instance.bias_axes,
num_microbatches,
)
return base_vars, outputs, sqr_norm_fn
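
As a quick illustration (not part of this commit), here is a minimal, hypothetical sketch of how the registry function above can be exercised directly; the layer configuration, toy loss, and shapes are illustrative and mirror test case C1 ('ab,bc->ac') from the tests below.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense

# Hypothetical layer and a batch of 8 examples with 3 features each.
layer = tf.keras.layers.EinsumDense("ab,bc->ac", [4], bias_axes="c")
x_batch = tf.random.normal([8, 3])

with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = einsum_dense.einsum_layer_computation(
      layer, (x_batch,), {}, tape
  )
  # Toy loss: a sum of per-example losses, so that the gradient with respect
  # to `base_vars` stacks one gradient slice per example.
  loss = tf.reduce_sum(tf.square(outputs))

grads = tape.gradient(loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(grads)  # 1D tensor of length 8.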


@@ -0,0 +1,171 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense
def get_einsum_layer_generators():
def pure_einsum_layer(equation, output_dims, bias_axes):
return tf.keras.layers.EinsumDense(
equation, output_dims, bias_axes=bias_axes
)
def sigmoid_einsum_layer(equation, output_dims, bias_axes):
return tf.keras.layers.EinsumDense(
equation, output_dims, bias_axes=bias_axes, activation='sigmoid'
)
return {
'pure_einsum': pure_einsum_layer,
'sigmoid_einsum': sigmoid_einsum_layer,
}
def get_einsum_parameter_tuples():
# (equation, input_dims, output_dims, bias_axes)
return [
# Case C1.
('ab,bc->ac', [2], [3], None),
('ab,bc->ac', [2], [3], 'c'),
('abc,cd->abd', [2, 3], [2, 4], None),
('abc,cd->abd', [2, 3], [2, 4], 'b'),
('abc,cd->abd', [2, 3], [2, 4], 'd'),
('abc,cd->abd', [2, 3], [2, 4], 'bd'),
('abc,cef->abef', [2, 3], [2, 4, 5], None),
('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'),
# Case C2.
('...b,bc->...c', [2, 3], [4], None),
('...b,bc->...c', [2, 3], [4], 'c'),
('...ab,bc->...ac', [2, 3], [2, 4], None),
('...ab,bc->...ac', [2, 4], [2, 4], 'c'),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'),
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'),
('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None),
('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'),
# Case C3.
('ab...,bc->ac...', [2, 3], [4, 3], None),
('ab...,bc->ac...', [2, 3], [4, 3], 'c'),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'),
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'),
('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None),
('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'),
]
def get_einsum_layer_registry():
einsum_registry = layer_registry.LayerRegistry()
einsum_registry.insert(
tfm.nlp.layers.EinsumDense,
einsum_dense.einsum_layer_computation,
)
return einsum_registry
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
self.using_tpu = False
@parameterized.product(
layer_name=list(get_einsum_layer_generators()),
param_tuple=get_einsum_parameter_tuples(),
num_microbatches=[None, 2],
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self,
layer_name,
param_tuple,
num_microbatches,
is_eager,
):
# Parse inputs to generate test data. Note that each batched input is a
# reshape of a `tf.range()` call.
equation, input_dims, output_dims, bias_axes = param_tuple
batch_size = 4
example_size = tf.reduce_prod(input_dims)
example_values = tf.range(batch_size * example_size, dtype=tf.float32)
x_batch = tf.reshape(example_values, [batch_size] + input_dims)
# Make the layer generator via currying.
einsum_generator = get_einsum_layer_generators()[layer_name]
def curried_generator(a, b):
del a, b
return einsum_generator(equation, output_dims, bias_axes)
# Load shared assets to all devices.
with self.strategy.scope():
model = common_test_utils.get_model_from_generator(
model_generator=common_test_utils.make_one_layer_functional_model,
layer_generator=curried_generator,
input_dims=input_dims,
output_dims=output_dims,
is_eager=is_eager,
)
    # Define the main testing ops. These may later be compiled into a graph op.
def test_op(x):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=None,
num_microbatches=num_microbatches,
x_batch=x,
registry=get_einsum_layer_registry(),
)
# TPUs can only run `tf.function`-decorated functions.
if self.using_tpu:
test_op = tf.function(test_op, autograph=False)
# TPUs use lower precision than CPUs, so we relax our criterion.
# E.g., one of the TPU runs generated the following results:
#
# computed_norm = 93.48296
# true_norm = 93.31176
# abs_diff = 0.17120361
# rel_diff = 0.00183475
#
# which is a reasonable level of error for computing gradient norms.
# Other trials also give an absolute (resp. relative) error of around
# 0.05 (resp. 0.0015).
rtol = 1e-2 if self.using_tpu else 1e-3
atol = 5e-1 if self.using_tpu else 1e-2
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
# TPUs return replica contexts, which must be unwrapped.
if self.using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
expected_size = num_microbatches or batch_size
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,30 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense_test
class GradNormTpuTest(einsum_dense_test.GradNormTest):
def setUp(self):
super(einsum_dense_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
self.using_tpu = True
if __name__ == '__main__':
tf.test.main()


@@ -17,9 +17,11 @@ import enum
import itertools
import os
import re
from typing import Optional

import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils

EquationType = enum.Enum(
    "EquationType",
@@ -139,9 +141,7 @@ def _reshape_einsum_inputs(
      (num_batches, num_rows, num_columns)
    ```
    When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1
-   and the number of output columns is the second dimension of the input. The
-   product of the non-trivial dimensions of the output should be equal to
-   the product of the dimensions of `input_tensor`.
+   and the number of output columns is the second dimension of the input.

  Raises:
    ValueError: If `equation` is not a valid einsum equation in the context of
@@ -345,3 +345,141 @@ def _get_einsum_bias_adjoint_reduction_axes(
    else:
      reduction_axes.append(idx)
  return reduction_axes
def compute_fast_einsum_squared_gradient_norm(
equation: str,
input_tensor: tf.Tensor,
grad_tensor: tf.Tensor,
bias_axes: Optional[str],
num_microbatches: Optional[int] = None,
) -> tf.Tensor:
"""Computes the batch gradient norms of an Einsum gradient decompostion.
This logic generalizes the one for `tf.keras.layers.Dense` and assumes that
the `equation` parameter is one of the following forms:
    C1. ab,bc->ac
C2. ...ab,bc->...ac
C3. ab...,bc->ac...
where `a`, `b`, and `c` are non-empty substrings.
For reference, we describe part of the mathematical analysis below. It can be
safely skipped upon the first reading of this docstring.
-----------------------------------------------------------------------------
BEGIN ANALYSIS
-----------------------------------------------------------------------------
  For ease of exposition, the analysis below is done for a single example,
  i.e., the batch dimension is excluded from consideration.
  Recall that the einsum dense computation, excluding activation functions, is
  of the form
```
output = tf.einsum(equation, input, kernel) + bias,
```
  where `bias` is broadcasted and summed with the output of the `tf.einsum()`
  call, and `equation` takes one of the forms C1, C2, or C3 above.
Mathematically, the above computation is equivalent to:
```
output = tf.matmul(X, W) + Q(bias)
```
where `X` (resp. `W`) is a 2D tensor reshaped from `input` (resp. `kernel`)
and `Q` is a linear operator that transforms `bias` to comport with the
tensor output by the `tf.matmul()` call. When generalizing to a batch of
examples, `X` is a 3D tensor whose first dimension is the batch dimension.
Following the same trick as for `tf.keras.layers.Dense` layers, suppose that
we have:
```
loss = f(output)
G = tape.gradient(loss, output)
```
  Then, using the chain rule and denoting by `A'` the adjoint of a matrix `A`,
  one can show that the gradient of `loss` with respect to the layer parameters
  `(W, bias)` is the block matrix `K := [X'G; Q'G]`. Hence, the squared norm of
  `K`, i.e., what is returned by `sqr_norm_fn`, is given by
```
sqr_norm = <XX', GG'> + ||Q'G||_F^2
```
where `||.||_F` is the Frobenius norm and `<.,.>` is the Euclidean inner
product for matrices.
-----------------------------------------------------------------------------
END ANALYSIS
-----------------------------------------------------------------------------
Args:
equation: A `string` representing the einsum equation.
    input_tensor: A `tf.Tensor` representing the einsum input.
grad_tensor: A `tf.Tensor` that is the gradient of the scalar loss with
respect to the pre-activation tensor.
bias_axes: An optional `string` that specifies the einsum biases in
`equation`.
num_microbatches: An optional `int` that specifies the number of
microbatches used in a batch.
Returns:
    A 1D `tf.Tensor` whose i-th entry is the squared gradient norm corresponding
    to the i-th example (or microbatch, if `num_microbatches` is set) in
    `input_tensor`.
"""
  # Compute the matrices `X X'` and `G G'` for each example or microbatch.
# `x.shape = (batch_size, num_rows, num_columns)`
x = _reshape_einsum_inputs(input_tensor, equation)
g = _reshape_einsum_outputs(grad_tensor, equation)
# Adding microbatches is equivalent to splitting the first `(batch_size)`
# axis into `(num_microbatches, microbatch_size)` axes and merging the
# `microbatch_size` axis with the `num_rows` axis via a reshape.
if num_microbatches is not None:
# `x.shape = (num_microbatches, microbatch_size, num_rows, num_columns)`
x = common_manip_utils.maybe_add_microbatch_axis(x, num_microbatches)
g = common_manip_utils.maybe_add_microbatch_axis(g, num_microbatches)
sx = tf.shape(x)
sg = tf.shape(g)
# `x.shape = (num_microbatches, microbatch_size * num_rows, num_columns)`
x = tf.reshape(x, shape=[sx[0], sx[1] * sx[2], sx[3]])
g = tf.reshape(g, shape=[sg[0], sg[1] * sg[2], sg[3]])
  # NOTE: When each example's input/gradient is a vector (i.e., the batched
  # tensors are rank-2), it is MUCH faster to do a `tf.square()` +
  # `tf.reduce_sum()` than a single `tf.matmul()`.
if (
_is_batch_of_vectors(input_tensor)
and _is_batch_of_vectors(grad_tensor)
and num_microbatches is None
):
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])
g_matrix = tf.reshape(g, [tf.shape(g)[0], -1])
batch_xxt = tf.reduce_sum(tf.square(x_matrix), axis=1)
batch_ggt = tf.reduce_sum(tf.square(g_matrix), axis=1)
else:
batch_xxt = tf.matmul(x, x, transpose_b=True)
batch_ggt = tf.matmul(g, g, transpose_b=True)
# Compute the (micro)batch inner product; adjust for biases if necessary.
batch_xxt_ggt = tf.multiply(batch_xxt, batch_ggt)
reduction_axes = tf.range(1, tf.rank(batch_xxt_ggt))
sqr_norms = tf.reduce_sum(batch_xxt_ggt, axis=reduction_axes)
if bias_axes is not None:
# The adjoint operator `Q` on `G` is a reduce sum on the axes in `G` that
# are not broadcasted from `bias`.
grad_rank = len(grad_tensor.shape)
adjoint_reduction_axes = _get_einsum_bias_adjoint_reduction_axes(
equation,
bias_axes,
grad_rank,
)
    # Adding microbatches with non-trivial bias axes is equivalent to splitting
# the first `(batch_size)` axis into `(num_microbatches, microbatch_size)`
# axes, and adding the `microbatch_size` axis (=1) to the reduction axes
# needed to compute the bias broadcast adjoint operator.
if num_microbatches is not None:
grad_tensor = common_manip_utils.maybe_add_microbatch_axis(
grad_tensor, num_microbatches
)
adjoint_reduction_axes = [i + 1 for i in adjoint_reduction_axes]
adjoint_reduction_axes = [1] + adjoint_reduction_axes
qg = tf.reduce_sum(grad_tensor, axis=adjoint_reduction_axes)
qg_reduction_axes = tf.range(1, tf.rank(qg))
bias_sqr_norms = tf.reduce_sum(tf.square(qg), axis=qg_reduction_axes)
sqr_norms += bias_sqr_norms
return sqr_norms
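
As a sanity check of the formula `sqr_norm = <XX', GG'> + ||Q'G||_F^2` derived above, the following sketch (not part of this commit) compares the fast computation against brute-force per-example gradient norms for a bias-free case C1 equation; all names and values below are illustrative.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils

equation = "ab,bc->ac"              # Case C1, no bias.
x_batch = tf.random.normal([4, 3])  # 4 examples, 3 input features each.
kernel = tf.random.normal([3, 5])

with tf.GradientTape() as tape:
  tape.watch(kernel)
  outputs = tf.einsum(equation, x_batch, kernel)  # No bias, no activation.
  loss = tf.reduce_sum(tf.square(outputs))        # Sum of per-example losses.

# `G`: gradient of the loss with respect to the pre-activation outputs.
g_batch = tape.gradient(loss, outputs)

fast_sqr_norms = einsum_utils.compute_fast_einsum_squared_gradient_norm(
    equation, x_batch, g_batch, bias_axes=None
)

# Brute force: the i-th per-example gradient with respect to `kernel` is the
# outer product of x_batch[i] and g_batch[i]; its squared Frobenius norm
# should match the fast computation above.
brute_sqr_norms = tf.stack([
    tf.reduce_sum(tf.square(tf.einsum("a,b->ab", x_batch[i], g_batch[i])))
    for i in range(4)
])
tf.debugging.assert_near(fast_sqr_norms, brute_sqr_norms)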


@@ -13,7 +13,7 @@
# limitations under the License.

import tensorflow as tf
-from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test
@@ -21,7 +21,7 @@ class GradNormTpuTest(layer_normalization_test.GradNormTest):
  def setUp(self):
    super(layer_normalization_test.GradNormTest, self).setUp()
-    self.strategy = ctu.create_tpu_strategy()
+    self.strategy = common_test_utils.create_tpu_strategy()
    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
    self.using_tpu = True