Implement and test a registry function for tfm.nlp.layers.EinsumDense
+ small formatting fixes.
PiperOrigin-RevId: 576215816
Parent: 8b52ba246c
Commit: 39c8a8c1af
6 changed files with 442 additions and 5 deletions

@@ -13,6 +13,7 @@ py_library(
     name = "einsum_utils",
     srcs = ["einsum_utils.py"],
     srcs_version = "PY3",
+    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils"],
 )

 py_test(

@@ -24,6 +25,33 @@ py_test(
     deps = [":einsum_utils"],
 )

+py_library(
+    name = "einsum_dense",
+    srcs = ["einsum_dense.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":einsum_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
+)
+
+py_test(
+    name = "einsum_dense_test",
+    size = "large",
+    srcs = ["einsum_dense_test.py"],
+    python_version = "PY3",
+    shard_count = 12,
+    srcs_version = "PY3",
+    deps = [
+        ":dense",
+        ":einsum_dense",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+    ],
+)
+
 py_library(
     name = "dense",
     srcs = ["dense.py"],

@@ -0,0 +1,70 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast clipping function for `tfm.nlp.layers.EinsumDense`."""
+
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
+
+
+def einsum_layer_computation(
+    layer_instance: tf.keras.layers.EinsumDense,
+    input_args: Sequence[Any],
+    input_kwargs: Mapping[str, Any],
+    tape: tf.GradientTape,
+    num_microbatches: Optional[tf.Tensor] = None,
+) -> type_aliases.RegistryFunctionOutput:
+  """Registry function for `tf.keras.layers.EinsumDense`.
+
+  For the technical details, see the documentation of
+  `einsum_utils.compute_fast_einsum_squared_gradient_norm()`.
+
+  Args:
+    layer_instance: A `tf.keras.layers.EinsumDense` instance.
+    input_args: See `dense_layer_computation()` in `dense.py`.
+    input_kwargs: See `dense_layer_computation()` in `dense.py`.
+    tape: See `dense_layer_computation()` in `dense.py`.
+    num_microbatches: See `dense_layer_computation()` in `dense.py`.
+
+  Returns:
+    See `dense_layer_computation()` in `dense.py`.
+  """
+  if input_kwargs:
+    raise ValueError("EinsumDense layer calls should not receive kwargs.")
+  del input_kwargs
+  if len(input_args) != 1:
+    raise ValueError("Only layer inputs of length 1 are permitted.")
+  orig_activation = layer_instance.activation
+  # Some activation functions may not apply a transform to the elements of the
+  # output individually (which is needed for the fast clipping trick to work).
+  # To avoid this case, we watch the variables that are only generated by the
+  # linear transformation of the `EinsumDense` layer instance.
+  layer_instance.activation = None
+  base_vars = layer_instance(*input_args)
+  tape.watch(base_vars)
+  layer_instance.activation = orig_activation
+  outputs = orig_activation(base_vars) if orig_activation else base_vars
+
+  def sqr_norm_fn(grads):
+    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
+        layer_instance.equation,
+        input_args[0],
+        grads,
+        layer_instance.bias_axes,
+        num_microbatches,
+    )
+
+  return base_vars, outputs, sqr_norm_fn

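For context, a minimal sketch (not part of the commit) of how this registry function is expected to be wired up: it is keyed by layer type in a `LayerRegistry`, exactly as the test helper `get_einsum_layer_registry()` in the new test file below does, so that the fast gradient clipping machinery can look it up when it encounters a `tfm.nlp.layers.EinsumDense` layer.

# Sketch only: mirrors get_einsum_layer_registry() from the new test file.
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense

registry = layer_registry.LayerRegistry()
registry.insert(
    tfm.nlp.layers.EinsumDense,
    einsum_dense.einsum_layer_computation,
)
# `registry` can then be passed to the fast-clipping entry points, e.g. via the
# `registry=` argument used by the tests below.
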
@@ -0,0 +1,171 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+import tensorflow as tf
+import tensorflow_models as tfm
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense
+
+
+def get_einsum_layer_generators():
+  def pure_einsum_layer(equation, output_dims, bias_axes):
+    return tf.keras.layers.EinsumDense(
+        equation, output_dims, bias_axes=bias_axes
+    )
+
+  def sigmoid_einsum_layer(equation, output_dims, bias_axes):
+    return tf.keras.layers.EinsumDense(
+        equation, output_dims, bias_axes=bias_axes, activation='sigmoid'
+    )
+
+  return {
+      'pure_einsum': pure_einsum_layer,
+      'sigmoid_einsum': sigmoid_einsum_layer,
+  }
+
+
+def get_einsum_parameter_tuples():
+  # (equation, input_dims, output_dims, bias_axes)
+  return [
+      # Case C1.
+      ('ab,bc->ac', [2], [3], None),
+      ('ab,bc->ac', [2], [3], 'c'),
+      ('abc,cd->abd', [2, 3], [2, 4], None),
+      ('abc,cd->abd', [2, 3], [2, 4], 'b'),
+      ('abc,cd->abd', [2, 3], [2, 4], 'd'),
+      ('abc,cd->abd', [2, 3], [2, 4], 'bd'),
+      ('abc,cef->abef', [2, 3], [2, 4, 5], None),
+      ('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'),
+      # Case C2.
+      ('...b,bc->...c', [2, 3], [4], None),
+      ('...b,bc->...c', [2, 3], [4], 'c'),
+      ('...ab,bc->...ac', [2, 3], [2, 4], None),
+      ('...ab,bc->...ac', [2, 4], [2, 4], 'c'),
+      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None),
+      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'),
+      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'),
+      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'),
+      ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None),
+      ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'),
+      # Case C3.
+      ('ab...,bc->ac...', [2, 3], [4, 3], None),
+      ('ab...,bc->ac...', [2, 3], [4, 3], 'c'),
+      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None),
+      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'),
+      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'),
+      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'),
+      ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None),
+      ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'),
+  ]
+
+
+def get_einsum_layer_registry():
+  einsum_registry = layer_registry.LayerRegistry()
+  einsum_registry.insert(
+      tfm.nlp.layers.EinsumDense,
+      einsum_dense.einsum_layer_computation,
+  )
+  return einsum_registry
+
+
+class GradNormTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False
+
+  @parameterized.product(
+      layer_name=list(get_einsum_layer_generators()),
+      param_tuple=get_einsum_parameter_tuples(),
+      num_microbatches=[None, 2],
+      is_eager=[True, False],
+  )
+  def test_gradient_norms_on_various_models(
+      self,
+      layer_name,
+      param_tuple,
+      num_microbatches,
+      is_eager,
+  ):
+    # Parse inputs to generate test data. Note that each batched input is a
+    # reshape of a `tf.range()` call.
+    equation, input_dims, output_dims, bias_axes = param_tuple
+    batch_size = 4
+    example_size = tf.reduce_prod(input_dims)
+    example_values = tf.range(batch_size * example_size, dtype=tf.float32)
+    x_batch = tf.reshape(example_values, [batch_size] + input_dims)
+
+    # Make the layer generator via currying.
+    einsum_generator = get_einsum_layer_generators()[layer_name]
+
+    def curried_generator(a, b):
+      del a, b
+      return einsum_generator(equation, output_dims, bias_axes)
+
+    # Load shared assets to all devices.
+    with self.strategy.scope():
+      model = common_test_utils.get_model_from_generator(
+          model_generator=common_test_utils.make_one_layer_functional_model,
+          layer_generator=curried_generator,
+          input_dims=input_dims,
+          output_dims=output_dims,
+          is_eager=is_eager,
+      )
+
+    # Define the main testing ops. These may be later compiled to a Graph op.
+    def test_op(x):
+      return common_test_utils.get_computed_and_true_norms_from_model(
+          model=model,
+          per_example_loss_fn=None,
+          num_microbatches=num_microbatches,
+          x_batch=x,
+          registry=get_einsum_layer_registry(),
+      )
+
+    # TPUs can only run `tf.function`-decorated functions.
+    if self.using_tpu:
+      test_op = tf.function(test_op, autograph=False)
+
+    # TPUs use lower precision than CPUs, so we relax our criterion.
+    # E.g., one of the TPU runs generated the following results:
+    #
+    #   computed_norm = 93.48296
+    #   true_norm = 93.31176
+    #   abs_diff = 0.17120361
+    #   rel_diff = 0.00183475
+    #
+    # which is a reasonable level of error for computing gradient norms.
+    # Other trials also give an absolute (resp. relative) error of around
+    # 0.05 (resp. 0.0015).
+    rtol = 1e-2 if self.using_tpu else 1e-3
+    atol = 5e-1 if self.using_tpu else 1e-2
+
+    # Set up the device ops and run the test.
+    computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
+    # TPUs return replica contexts, which must be unwrapped.
+    if self.using_tpu:
+      common_test_utils.assert_replica_values_are_close(self, computed_norms)
+      common_test_utils.assert_replica_values_are_close(self, true_norms)
+      computed_norms = computed_norms.values[0]
+      true_norms = true_norms.values[0]
+    expected_size = num_microbatches or batch_size
+    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
+    self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
+
+
+if __name__ == '__main__':
+  tf.test.main()

@@ -0,0 +1,30 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense_test
+
+
+class GradNormTpuTest(einsum_dense_test.GradNormTest):
+
+  def setUp(self):
+    super(einsum_dense_test.GradNormTest, self).setUp()
+    self.strategy = common_test_utils.create_tpu_strategy()
+    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True
+
+
+if __name__ == '__main__':
+  tf.test.main()

@@ -17,9 +17,11 @@ import enum
 import itertools
 import os
 import re
+from typing import Optional

 import numpy as np
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils

 EquationType = enum.Enum(
     "EquationType",

@@ -139,9 +141,7 @@ def _reshape_einsum_inputs(
     (num_batches, num_rows, num_columns)
   ```
   When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1
-  and the number of output columns is the second dimension of the input. The
-  product of the non-trivial dimensions of the output should be equal to
-  the product of the dimensions of `input_tensor`.
+  and the number of output columns is the second dimension of the input.

   Raises:
     ValueError: If `equation` is not a valid einsum equation in the context of

@@ -345,3 +345,141 @@ def _get_einsum_bias_adjoint_reduction_axes(
     else:
       reduction_axes.append(idx)
   return reduction_axes
+
+
+def compute_fast_einsum_squared_gradient_norm(
+    equation: str,
+    input_tensor: tf.Tensor,
+    grad_tensor: tf.Tensor,
+    bias_axes: Optional[str],
+    num_microbatches: Optional[int] = None,
+) -> tf.Tensor:
+  """Computes the batch gradient norms of an Einsum gradient decomposition.
+
+  This logic generalizes the one for `tf.keras.layers.Dense` and assumes that
+  the `equation` parameter is one of the following forms:
+
+    C1. ab,bc->ac
+    C2. ...ab,bc->...ac
+    C3. ab...,bc->ac...
+
+  where `a`, `b`, and `c` are non-empty substrings.
+
+  For reference, we describe part of the mathematical analysis below. It can be
+  safely skipped upon the first reading of this docstring.
+
+  -----------------------------------------------------------------------------
+  BEGIN ANALYSIS
+  -----------------------------------------------------------------------------
+  For ease of exposition, all analysis is done for a single example, i.e., the
+  batch dimension is excluded from consideration.
+
+  Recall that the einsum dense computation, excluding activation functions, is
+  of the form
+  ```
+  output = tf.einsum(equation, input, kernel) + bias,
+  ```
+  where `bias` is broadcasted and summed with the output of the `tf.einsum()`
+  call, and `equation` is of one of the forms C1, C2, or C3.
+
+  Mathematically, the above computation is equivalent to:
+  ```
+  output = tf.matmul(X, W) + Q(bias)
+  ```
+  where `X` (resp. `W`) is a 2D tensor reshaped from `input` (resp. `kernel`)
+  and `Q` is a linear operator that transforms `bias` to comport with the
+  tensor output by the `tf.matmul()` call. When generalizing to a batch of
+  examples, `X` is a 3D tensor whose first dimension is the batch dimension.
+
+  Following the same trick as for `tf.keras.layers.Dense` layers, suppose that
+  we have:
+  ```
+  loss = f(output)
+  G = tape.gradient(loss, output)
+  ```
+  Then, using the chain rule and denoting `A'` to be the adjoint of a matrix
+  `A`, one can show that the gradient of `loss` with respect to `W` is given by
+  the block matrix `K := [X'G; Q'G]`. Hence, the square norm of `K`, i.e., what
+  is returned by `sqr_norm_fn`, is given by
+  ```
+  sqr_norm = <XX', GG'> + ||Q'G||_F^2
+  ```
+  where `||.||_F` is the Frobenius norm and `<.,.>` is the Euclidean inner
+  product for matrices.
+  -----------------------------------------------------------------------------
+  END ANALYSIS
+  -----------------------------------------------------------------------------
+
+  Args:
+    equation: A `string` representing the einsum equation.
+    input_tensor: A `tf.Tensor` representing the einsum input.
+    grad_tensor: A `tf.Tensor` that is the gradient of the scalar loss with
+      respect to the pre-activation tensor.
+    bias_axes: An optional `string` that specifies the einsum biases in
+      `equation`.
+    num_microbatches: An optional `int` that specifies the number of
+      microbatches used in a batch.
+
+  Returns:
+    A 1D `tf.Tensor` whose i-th entry is the squared gradient norm
+    corresponding to the i-th example in `input_tensor`.
+  """
+  # Compute the matrices `X X'` and `G G'` for each example or microbatch.
+  # `x.shape = (batch_size, num_rows, num_columns)`
+  x = _reshape_einsum_inputs(input_tensor, equation)
+  g = _reshape_einsum_outputs(grad_tensor, equation)
+  # Adding microbatches is equivalent to splitting the first `(batch_size)`
+  # axis into `(num_microbatches, microbatch_size)` axes and merging the
+  # `microbatch_size` axis with the `num_rows` axis via a reshape.
+  if num_microbatches is not None:
+    # `x.shape = (num_microbatches, microbatch_size, num_rows, num_columns)`
+    x = common_manip_utils.maybe_add_microbatch_axis(x, num_microbatches)
+    g = common_manip_utils.maybe_add_microbatch_axis(g, num_microbatches)
+    sx = tf.shape(x)
+    sg = tf.shape(g)
+    # `x.shape = (num_microbatches, microbatch_size * num_rows, num_columns)`
+    x = tf.reshape(x, shape=[sx[0], sx[1] * sx[2], sx[3]])
+    g = tf.reshape(g, shape=[sg[0], sg[1] * sg[2], sg[3]])
+  # NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
+  # a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
+  if (
+      _is_batch_of_vectors(input_tensor)
+      and _is_batch_of_vectors(grad_tensor)
+      and num_microbatches is None
+  ):
+    x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])
+    g_matrix = tf.reshape(g, [tf.shape(g)[0], -1])
+    batch_xxt = tf.reduce_sum(tf.square(x_matrix), axis=1)
+    batch_ggt = tf.reduce_sum(tf.square(g_matrix), axis=1)
+  else:
+    batch_xxt = tf.matmul(x, x, transpose_b=True)
+    batch_ggt = tf.matmul(g, g, transpose_b=True)
+  # Compute the (micro)batch inner product; adjust for biases if necessary.
+  batch_xxt_ggt = tf.multiply(batch_xxt, batch_ggt)
+  reduction_axes = tf.range(1, tf.rank(batch_xxt_ggt))
+  sqr_norms = tf.reduce_sum(batch_xxt_ggt, axis=reduction_axes)
+  if bias_axes is not None:
+    # The adjoint operator `Q` on `G` is a reduce sum on the axes in `G` that
+    # are not broadcasted from `bias`.
+    grad_rank = len(grad_tensor.shape)
+    adjoint_reduction_axes = _get_einsum_bias_adjoint_reduction_axes(
+        equation,
+        bias_axes,
+        grad_rank,
+    )
+    # Adding microbatches with non-trivial bias axes is equivalent to splitting
+    # the first `(batch_size)` axis into `(num_microbatches, microbatch_size)`
+    # axes, and adding the `microbatch_size` axis (=1) to the reduction axes
+    # needed to compute the bias broadcast adjoint operator.
+    if num_microbatches is not None:
+      grad_tensor = common_manip_utils.maybe_add_microbatch_axis(
+          grad_tensor, num_microbatches
+      )
+      adjoint_reduction_axes = [i + 1 for i in adjoint_reduction_axes]
+      adjoint_reduction_axes = [1] + adjoint_reduction_axes
+    qg = tf.reduce_sum(grad_tensor, axis=adjoint_reduction_axes)
+    qg_reduction_axes = tf.range(1, tf.rank(qg))
+    bias_sqr_norms = tf.reduce_sum(tf.square(qg), axis=qg_reduction_axes)
+    sqr_norms += bias_sqr_norms
+
+  return sqr_norms

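The identity in the docstring above can be sanity-checked numerically with plain TensorFlow. The sketch below is not part of the commit; it assumes the simplest case C1 (`ab,bc->ac`) with no bias and a loss that decomposes across examples, in which case `<XX', GG'>` for the i-th example collapses to `||x_i||^2 * ||g_i||^2`, and compares that against brute-force per-example kernel gradients.

# Illustrative sanity check of the fast-norm identity (not part of the commit).
# Case C1 (`ab,bc->ac`), no bias, loss that decomposes across examples.
import tensorflow as tf

batch_size, d_in, d_out = 4, 3, 2
x = tf.random.normal([batch_size, d_in])
kernel = tf.Variable(tf.random.normal([d_in, d_out]))

with tf.GradientTape() as tape:
  base_vars = tf.einsum('ab,bc->ac', x, kernel)  # pre-activation outputs
  tape.watch(base_vars)
  loss = tf.reduce_sum(tf.square(base_vars))
grads = tape.gradient(loss, base_vars)  # `G`, shape (batch_size, d_out)

# Fast formula: <X X', G G'> per example reduces to ||x_i||^2 * ||g_i||^2
# when the reshaped inputs/gradients are batches of vectors.
fast_sqr_norms = tf.reduce_sum(tf.square(x), axis=1) * tf.reduce_sum(
    tf.square(grads), axis=1
)

# Brute force: differentiate each example's loss w.r.t. the kernel directly
# and take the squared Frobenius norm of the per-example gradient.
brute_sqr_norms = []
for i in range(batch_size):
  with tf.GradientTape() as example_tape:
    z_i = tf.einsum('ab,bc->ac', x[i : i + 1], kernel)
    loss_i = tf.reduce_sum(tf.square(z_i))
  per_example_grad = example_tape.gradient(loss_i, kernel)
  brute_sqr_norms.append(tf.reduce_sum(tf.square(per_example_grad)))

tf.debugging.assert_near(fast_sqr_norms, tf.stack(brute_sqr_norms))
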
@@ -13,7 +13,7 @@
 # limitations under the License.

 import tensorflow as tf
-from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test


@@ -21,7 +21,7 @@ class GradNormTpuTest(layer_normalization_test.GradNormTest):

   def setUp(self):
     super(layer_normalization_test.GradNormTest, self).setUp()
-    self.strategy = ctu.create_tpu_strategy()
+    self.strategy = common_test_utils.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
     self.using_tpu = True