Implement and test a registry function for tfm.nlp.layers.EinsumDense
+ small formatting fixes.
PiperOrigin-RevId: 576215816
This commit is contained in:
parent
8b52ba246c
commit
39c8a8c1af
6 changed files with 442 additions and 5 deletions
|
@ -13,6 +13,7 @@ py_library(
|
|||
name = "einsum_utils",
|
||||
srcs = ["einsum_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -24,6 +25,33 @@ py_test(
|
|||
deps = [":einsum_utils"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "einsum_dense",
|
||||
srcs = ["einsum_dense.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":einsum_utils",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "einsum_dense_test",
|
||||
size = "large",
|
||||
srcs = ["einsum_dense_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 12,
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dense",
|
||||
":einsum_dense",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
|
||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dense",
|
||||
srcs = ["dense.py"],
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2023, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast clipping function for `tfm.nlp.layers.EinsumDense`."""
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
|
||||
|
||||
|
||||
def einsum_layer_computation(
|
||||
layer_instance: tf.keras.layers.EinsumDense,
|
||||
input_args: Sequence[Any],
|
||||
input_kwargs: Mapping[str, Any],
|
||||
tape: tf.GradientTape,
|
||||
num_microbatches: Optional[tf.Tensor] = None,
|
||||
) -> type_aliases.RegistryFunctionOutput:
|
||||
"""Registry function for `tf.keras.layers.EinsumDense`.
|
||||
|
||||
For the technical details, see the documentation of
|
||||
`einsum_utils.compute_fast_einsum_gradient_norm()`.
|
||||
|
||||
Args:
|
||||
layer_instance: A `tf.keras.layers.EinsumDense` instance.
|
||||
input_args: See `dense_layer_computation()` in `dense.py`.
|
||||
input_kwargs: See `dense_layer_computation()` in `dense.py`.
|
||||
tape: See `dense_layer_computation()` in `dense.py`.
|
||||
num_microbatches: See `dense_layer_computation()` in `dense.py`.
|
||||
|
||||
Returns:
|
||||
See `dense_layer_computation()` in `dense.py`.
|
||||
"""
|
||||
if input_kwargs:
|
||||
raise ValueError("EinsumDense layer calls should not receive kwargs.")
|
||||
del input_kwargs
|
||||
if len(input_args) != 1:
|
||||
raise ValueError("Only layer inputs of length 1 are permitted.")
|
||||
orig_activation = layer_instance.activation
|
||||
# Some activation functions may not apply a transform to the elements of the
|
||||
# output individually (which is needed for the fast clipping trick to work).
|
||||
# To avoid this case, we watch the variables that are only generated by the
|
||||
# linear transformation of the `EinsumDense` layer instance.
|
||||
layer_instance.activation = None
|
||||
base_vars = layer_instance(*input_args)
|
||||
tape.watch(base_vars)
|
||||
layer_instance.activation = orig_activation
|
||||
outputs = orig_activation(base_vars) if orig_activation else base_vars
|
||||
|
||||
def sqr_norm_fn(grads):
|
||||
return einsum_utils.compute_fast_einsum_squared_gradient_norm(
|
||||
layer_instance.equation,
|
||||
input_args[0],
|
||||
grads,
|
||||
layer_instance.bias_axes,
|
||||
num_microbatches,
|
||||
)
|
||||
|
||||
return base_vars, outputs, sqr_norm_fn
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright 2023, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
import tensorflow_models as tfm
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense
|
||||
|
||||
|
||||
def get_einsum_layer_generators():
|
||||
def pure_einsum_layer(equation, output_dims, bias_axes):
|
||||
return tf.keras.layers.EinsumDense(
|
||||
equation, output_dims, bias_axes=bias_axes
|
||||
)
|
||||
|
||||
def sigmoid_einsum_layer(equation, output_dims, bias_axes):
|
||||
return tf.keras.layers.EinsumDense(
|
||||
equation, output_dims, bias_axes=bias_axes, activation='sigmoid'
|
||||
)
|
||||
|
||||
return {
|
||||
'pure_einsum': pure_einsum_layer,
|
||||
'sigmoid_einsum': sigmoid_einsum_layer,
|
||||
}
|
||||
|
||||
|
||||
def get_einsum_parameter_tuples():
|
||||
# (equation, input_dims, output_dims, bias_axes)
|
||||
return [
|
||||
# Case C1.
|
||||
('ab,bc->ac', [2], [3], None),
|
||||
('ab,bc->ac', [2], [3], 'c'),
|
||||
('abc,cd->abd', [2, 3], [2, 4], None),
|
||||
('abc,cd->abd', [2, 3], [2, 4], 'b'),
|
||||
('abc,cd->abd', [2, 3], [2, 4], 'd'),
|
||||
('abc,cd->abd', [2, 3], [2, 4], 'bd'),
|
||||
('abc,cef->abef', [2, 3], [2, 4, 5], None),
|
||||
('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'),
|
||||
# Case C2.
|
||||
('...b,bc->...c', [2, 3], [4], None),
|
||||
('...b,bc->...c', [2, 3], [4], 'c'),
|
||||
('...ab,bc->...ac', [2, 3], [2, 4], None),
|
||||
('...ab,bc->...ac', [2, 4], [2, 4], 'c'),
|
||||
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None),
|
||||
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'),
|
||||
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'),
|
||||
('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'),
|
||||
('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None),
|
||||
('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'),
|
||||
# Case C3.
|
||||
('ab...,bc->ac...', [2, 3], [4, 3], None),
|
||||
('ab...,bc->ac...', [2, 3], [4, 3], 'c'),
|
||||
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None),
|
||||
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'),
|
||||
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'),
|
||||
('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'),
|
||||
('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None),
|
||||
('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'),
|
||||
]
|
||||
|
||||
|
||||
def get_einsum_layer_registry():
|
||||
einsum_registry = layer_registry.LayerRegistry()
|
||||
einsum_registry.insert(
|
||||
tfm.nlp.layers.EinsumDense,
|
||||
einsum_dense.einsum_layer_computation,
|
||||
)
|
||||
return einsum_registry
|
||||
|
||||
|
||||
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.strategy = tf.distribute.get_strategy()
|
||||
self.using_tpu = False
|
||||
|
||||
@parameterized.product(
|
||||
layer_name=list(get_einsum_layer_generators()),
|
||||
param_tuple=get_einsum_parameter_tuples(),
|
||||
num_microbatches=[None, 2],
|
||||
is_eager=[True, False],
|
||||
)
|
||||
def test_gradient_norms_on_various_models(
|
||||
self,
|
||||
layer_name,
|
||||
param_tuple,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
):
|
||||
# Parse inputs to generate test data. Note that each batched input is a
|
||||
# reshape of a `tf.range()` call.
|
||||
equation, input_dims, output_dims, bias_axes = param_tuple
|
||||
batch_size = 4
|
||||
example_size = tf.reduce_prod(input_dims)
|
||||
example_values = tf.range(batch_size * example_size, dtype=tf.float32)
|
||||
x_batch = tf.reshape(example_values, [batch_size] + input_dims)
|
||||
|
||||
# Make the layer generator via currying.
|
||||
einsum_generator = get_einsum_layer_generators()[layer_name]
|
||||
|
||||
def curried_generator(a, b):
|
||||
del a, b
|
||||
return einsum_generator(equation, output_dims, bias_axes)
|
||||
|
||||
# Load shared assets to all devices.
|
||||
with self.strategy.scope():
|
||||
model = common_test_utils.get_model_from_generator(
|
||||
model_generator=common_test_utils.make_one_layer_functional_model,
|
||||
layer_generator=curried_generator,
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
is_eager=is_eager,
|
||||
)
|
||||
|
||||
# Define the main testing ops. These may be later compiled to a Graph op.
|
||||
def test_op(x):
|
||||
return common_test_utils.get_computed_and_true_norms_from_model(
|
||||
model=model,
|
||||
per_example_loss_fn=None,
|
||||
num_microbatches=num_microbatches,
|
||||
x_batch=x,
|
||||
registry=get_einsum_layer_registry(),
|
||||
)
|
||||
|
||||
# TPUs can only run `tf.function`-decorated functions.
|
||||
if self.using_tpu:
|
||||
test_op = tf.function(test_op, autograph=False)
|
||||
|
||||
# TPUs use lower precision than CPUs, so we relax our criterion.
|
||||
# E.g., one of the TPU runs generated the following results:
|
||||
#
|
||||
# computed_norm = 93.48296
|
||||
# true_norm = 93.31176
|
||||
# abs_diff = 0.17120361
|
||||
# rel_diff = 0.00183475
|
||||
#
|
||||
# which is a reasonable level of error for computing gradient norms.
|
||||
# Other trials also give an absolute (resp. relative) error of around
|
||||
# 0.05 (resp. 0.0015).
|
||||
rtol = 1e-2 if self.using_tpu else 1e-3
|
||||
atol = 5e-1 if self.using_tpu else 1e-2
|
||||
|
||||
# Set up the device ops and run the test.
|
||||
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
|
||||
# TPUs return replica contexts, which must be unwrapped.
|
||||
if self.using_tpu:
|
||||
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
||||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||
computed_norms = computed_norms.values[0]
|
||||
true_norms = true_norms.values[0]
|
||||
expected_size = num_microbatches or batch_size
|
||||
self.assertEqual(tf.shape(computed_norms)[0], expected_size)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2023, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense_test
|
||||
|
||||
|
||||
class GradNormTpuTest(einsum_dense_test.GradNormTest):
|
||||
|
||||
def setUp(self):
|
||||
super(einsum_dense_test.GradNormTest, self).setUp()
|
||||
self.strategy = common_test_utils.create_tpu_strategy()
|
||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||
self.using_tpu = True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -17,9 +17,11 @@ import enum
|
|||
import itertools
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
||||
|
||||
EquationType = enum.Enum(
|
||||
"EquationType",
|
||||
|
@ -139,9 +141,7 @@ def _reshape_einsum_inputs(
|
|||
(num_batches, num_rows, num_columns)
|
||||
```
|
||||
When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1
|
||||
and the number of output columns is the second dimension of the input. The
|
||||
product of the non-trivial dimensions of the output should be equal to
|
||||
the product of the dimensions of `input_tensor`.
|
||||
and the number of output columns is the second dimension of the input.
|
||||
|
||||
Raises:
|
||||
ValueError: If `equation` is not a valid einsum equation in the context of
|
||||
|
@ -345,3 +345,141 @@ def _get_einsum_bias_adjoint_reduction_axes(
|
|||
else:
|
||||
reduction_axes.append(idx)
|
||||
return reduction_axes
|
||||
|
||||
|
||||
def compute_fast_einsum_squared_gradient_norm(
|
||||
equation: str,
|
||||
input_tensor: tf.Tensor,
|
||||
grad_tensor: tf.Tensor,
|
||||
bias_axes: Optional[str],
|
||||
num_microbatches: Optional[int] = None,
|
||||
) -> tf.Tensor:
|
||||
"""Computes the batch gradient norms of an Einsum gradient decompostion.
|
||||
|
||||
This logic generalizes the one for `tf.keras.layers.Dense` and assumes that
|
||||
the `equation` parameter is one of the following forms:
|
||||
|
||||
C1. ab,bc->ac,
|
||||
C2. ...ab,bc->...ac
|
||||
C3. ab...,bc->ac...
|
||||
|
||||
where `a`, `b`, and `c` are non-empty substrings.
|
||||
|
||||
For reference, we describe part of the mathematical analysis below. It can be
|
||||
safely skipped upon the first reading of this docstring.
|
||||
|
||||
-----------------------------------------------------------------------------
|
||||
BEGIN ANALYSIS
|
||||
-----------------------------------------------------------------------------
|
||||
For ease of exposition, all analysis is done for a single example, i.e.,
|
||||
batch dimension is excluded from our consideration.
|
||||
|
||||
Recall that the einsum dense computation, excluding activation functions is of
|
||||
the form
|
||||
```
|
||||
output = tf.einsum(equation, input, kernel) + bias,
|
||||
```
|
||||
where `bias` is broadcasted and summed with the output of the `tf.einsum()`
|
||||
call, and equation is of the forms in C1, C2, and C3.
|
||||
|
||||
Mathematically, the above computation is equivalent to:
|
||||
```
|
||||
output = tf.matmul(X, W) + Q(bias)
|
||||
```
|
||||
where `X` (resp. `W`) is a 2D tensor reshaped from `input` (resp. `kernel`)
|
||||
and `Q` is a linear operator that transforms `bias` to comport with the
|
||||
tensor output by the `tf.matmul()` call. When generalizing to a batch of
|
||||
examples, `X` is a 3D tensor whose first dimension is the batch dimension.
|
||||
|
||||
Following the same trick as for `tf.keras.layers.Dense` layers, suppose that
|
||||
we have:
|
||||
```
|
||||
loss = f(output)
|
||||
G = tape.gradient(loss, output)
|
||||
```
|
||||
Then, using the chain rule and denoting `A'` to be the adjoint of a matrix
|
||||
`A`, one can show that the gradient of `loss` with respect to `W` is given by
|
||||
the block matrix `K := [X'G; Q'G]`. Hence, the square norm of `K`, i.e., what
|
||||
is returned by `sqr_norm_fn` is given by
|
||||
```
|
||||
sqr_norm = <XX', GG'> + ||Q'G||_F^2
|
||||
```
|
||||
where `||.||_F` is the Frobenius norm and `<.,.>` is the Euclidean inner
|
||||
product for matrices.
|
||||
-----------------------------------------------------------------------------
|
||||
END ANALYSIS
|
||||
-----------------------------------------------------------------------------
|
||||
|
||||
Args:
|
||||
equation: A `string` representing the einsum equation.
|
||||
input_tensor: A `tf.Tensor` reprenting the einsum input.
|
||||
grad_tensor: A `tf.Tensor` that is the gradient of the scalar loss with
|
||||
respect to the pre-activation tensor.
|
||||
bias_axes: An optional `string` that specifies the einsum biases in
|
||||
`equation`.
|
||||
num_microbatches: An optional `int` that specifies the number of
|
||||
microbatches used in a batch.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` whose i-th entry is the squared gradient corresponding
|
||||
to the i-th example in `input_tensor`.
|
||||
"""
|
||||
# Compute the matrix `X X'` and `G G'` for each example or microbatch.
|
||||
# `x.shape = (batch_size, num_rows, num_columns)`
|
||||
x = _reshape_einsum_inputs(input_tensor, equation)
|
||||
g = _reshape_einsum_outputs(grad_tensor, equation)
|
||||
# Adding microbatches is equivalent to splitting the first `(batch_size)`
|
||||
# axis into `(num_microbatches, microbatch_size)` axes and merging the
|
||||
# `microbatch_size` axis with the `num_rows` axis via a reshape.
|
||||
if num_microbatches is not None:
|
||||
# `x.shape = (num_microbatches, microbatch_size, num_rows, num_columns)`
|
||||
x = common_manip_utils.maybe_add_microbatch_axis(x, num_microbatches)
|
||||
g = common_manip_utils.maybe_add_microbatch_axis(g, num_microbatches)
|
||||
sx = tf.shape(x)
|
||||
sg = tf.shape(g)
|
||||
# `x.shape = (num_microbatches, microbatch_size * num_rows, num_columns)`
|
||||
x = tf.reshape(x, shape=[sx[0], sx[1] * sx[2], sx[3]])
|
||||
g = tf.reshape(g, shape=[sg[0], sg[1] * sg[2], sg[3]])
|
||||
# NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
|
||||
# a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
|
||||
if (
|
||||
_is_batch_of_vectors(input_tensor)
|
||||
and _is_batch_of_vectors(grad_tensor)
|
||||
and num_microbatches is None
|
||||
):
|
||||
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])
|
||||
g_matrix = tf.reshape(g, [tf.shape(g)[0], -1])
|
||||
batch_xxt = tf.reduce_sum(tf.square(x_matrix), axis=1)
|
||||
batch_ggt = tf.reduce_sum(tf.square(g_matrix), axis=1)
|
||||
else:
|
||||
batch_xxt = tf.matmul(x, x, transpose_b=True)
|
||||
batch_ggt = tf.matmul(g, g, transpose_b=True)
|
||||
# Compute the (micro)batch inner product; adjust for biases if necessary.
|
||||
batch_xxt_ggt = tf.multiply(batch_xxt, batch_ggt)
|
||||
reduction_axes = tf.range(1, tf.rank(batch_xxt_ggt))
|
||||
sqr_norms = tf.reduce_sum(batch_xxt_ggt, axis=reduction_axes)
|
||||
if bias_axes is not None:
|
||||
# The adjoint operator `Q` on `G` is a reduce sum on the axes in `G` that
|
||||
# are not broadcasted from `bias`.
|
||||
grad_rank = len(grad_tensor.shape)
|
||||
adjoint_reduction_axes = _get_einsum_bias_adjoint_reduction_axes(
|
||||
equation,
|
||||
bias_axes,
|
||||
grad_rank,
|
||||
)
|
||||
# Adding microbatches with non-trival bias axes is equivalent to splitting
|
||||
# the first `(batch_size)` axis into `(num_microbatches, microbatch_size)`
|
||||
# axes, and adding the `microbatch_size` axis (=1) to the reduction axes
|
||||
# needed to compute the bias broadcast adjoint operator.
|
||||
if num_microbatches is not None:
|
||||
grad_tensor = common_manip_utils.maybe_add_microbatch_axis(
|
||||
grad_tensor, num_microbatches
|
||||
)
|
||||
adjoint_reduction_axes = [i + 1 for i in adjoint_reduction_axes]
|
||||
adjoint_reduction_axes = [1] + adjoint_reduction_axes
|
||||
qg = tf.reduce_sum(grad_tensor, axis=adjoint_reduction_axes)
|
||||
qg_reduction_axes = tf.range(1, tf.rank(qg))
|
||||
bias_sqr_norms = tf.reduce_sum(tf.square(qg), axis=qg_reduction_axes)
|
||||
sqr_norms += bias_sqr_norms
|
||||
|
||||
return sqr_norms
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test
|
||||
|
||||
|
||||
|
@ -21,7 +21,7 @@ class GradNormTpuTest(layer_normalization_test.GradNormTest):
|
|||
|
||||
def setUp(self):
|
||||
super(layer_normalization_test.GradNormTest, self).setUp()
|
||||
self.strategy = ctu.create_tpu_strategy()
|
||||
self.strategy = common_test_utils.create_tpu_strategy()
|
||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||
self.using_tpu = True
|
||||
|
||||
|
|
Loading…
Reference in a new issue