forked from 626_privacy/tensorflow_privacy
Clip (per-example) and aggregate gradients.
PiperOrigin-RevId: 480761907
This commit is contained in:
parent
71837fbeec
commit
c25cb4a41b
3 changed files with 480 additions and 0 deletions
|
@ -9,6 +9,18 @@ py_library(
|
||||||
srcs = ["__init__.py"],
|
srcs = ["__init__.py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "clip_and_aggregate_gradients",
|
||||||
|
srcs = [
|
||||||
|
"clip_and_aggregate_gradients.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"//third_party/py/six",
|
||||||
|
"//third_party/tensorflow/python/ops/parallel_for:control_flow_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "dp_optimizer",
|
name = "dp_optimizer",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -63,6 +75,14 @@ py_library(
|
||||||
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
|
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "clip_and_aggregate_gradients_test",
|
||||||
|
srcs = ["clip_and_aggregate_gradients_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":clip_and_aggregate_gradients"],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "dp_optimizer_test",
|
name = "dp_optimizer_test",
|
||||||
timeout = "long",
|
timeout = "long",
|
||||||
|
|
|
@ -0,0 +1,247 @@
|
||||||
|
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Per example gradients clipping and aggregation for sparse gradients.
|
||||||
|
|
||||||
|
Modified from tape.jacobian to support sparse gradients.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
|
import six
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import
|
||||||
|
|
||||||
|
GradientTensor = Union[tf.Tensor, tf.IndexedSlices]
|
||||||
|
T = TypeVar('T')
|
||||||
|
Nested = Union[T, Tuple[Any, ...], List[Any], Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate_batch_indexed_slices(
|
||||||
|
batched_values: tf.Tensor,
|
||||||
|
indices: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
|
"""Removes duplication of indexed slices by summing them."""
|
||||||
|
perm = tf.concat([
|
||||||
|
tf.constant([1, 0], dtype=tf.int32),
|
||||||
|
tf.range(tf.rank(batched_values))[2:]
|
||||||
|
],
|
||||||
|
axis=0)
|
||||||
|
batched_values = tf.transpose(batched_values, perm=perm)
|
||||||
|
unique_indices, new_pos = tf.unique(indices)
|
||||||
|
summed_values = tf.math.unsorted_segment_sum(batched_values, new_pos,
|
||||||
|
tf.shape(unique_indices)[0])
|
||||||
|
return tf.transpose(summed_values, perm=perm), unique_indices
|
||||||
|
|
||||||
|
|
||||||
|
def _batch_global_norm(vals: List[tf.Tensor]) -> tf.Tensor:
|
||||||
|
"""Computes the global norm for each row in the batch."""
|
||||||
|
|
||||||
|
def _norm_squared(v):
|
||||||
|
return tf.cast(
|
||||||
|
tf.reduce_sum(
|
||||||
|
tf.reshape(tf.square(v), tf.stack([tf.shape(v)[0], -1])), axis=1),
|
||||||
|
tf.float32)
|
||||||
|
|
||||||
|
return tf.sqrt(tf.add_n([_norm_squared(v) for v in vals if v is not None]))
|
||||||
|
|
||||||
|
|
||||||
|
def _batch_clip_by_global_norm(
|
||||||
|
vals: List[tf.Tensor], normalize: bool,
|
||||||
|
l2_norm_clip: Optional[float]) -> List[tf.Tensor]:
|
||||||
|
"""Batch clips by global norm with normalize option."""
|
||||||
|
batch_global_norm = _batch_global_norm(vals)
|
||||||
|
if l2_norm_clip is None:
|
||||||
|
l2_norm_clip = 1.0
|
||||||
|
clip_ratio = l2_norm_clip / tf.maximum(batch_global_norm, 1e-8)
|
||||||
|
if not normalize:
|
||||||
|
clip_ratio = tf.minimum(1.0, clip_ratio)
|
||||||
|
|
||||||
|
def _expand_dims(e, v):
|
||||||
|
new_shape = tf.concat(
|
||||||
|
[tf.shape(v)[0:1],
|
||||||
|
tf.ones_like(tf.shape(v), dtype=tf.int32)[:-1]],
|
||||||
|
axis=0)
|
||||||
|
return tf.reshape(e, new_shape)
|
||||||
|
|
||||||
|
return [
|
||||||
|
v *
|
||||||
|
_expand_dims(tf.cast(clip_ratio, v.dtype), v) if v is not None else None
|
||||||
|
for v in vals
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def clip_and_aggregate_gradients(
|
||||||
|
tape: tf.GradientTape,
|
||||||
|
target: tf.Tensor,
|
||||||
|
sources: Nested[tf.Tensor],
|
||||||
|
unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients
|
||||||
|
.NONE,
|
||||||
|
normalize: bool = False,
|
||||||
|
l2_norm_clip: Optional[float] = None,
|
||||||
|
aggregate_method: str = 'mean',
|
||||||
|
keep_sparse_threshold: int = 10000) -> Nested[GradientTensor]:
|
||||||
|
"""Clips (per-example) and aggregates gradients.
|
||||||
|
|
||||||
|
This procedure computes the Jacobian with respect to a vectorized loss,
|
||||||
|
i.e. the `target` argument, clips the gradient with repsect to each
|
||||||
|
individual output, and sums the clipped gradients. This is correct as
|
||||||
|
per-example gradient if there is a one to one mapping from the input example
|
||||||
|
to the output loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tape: a persistent tape.
|
||||||
|
target: Tensor to be differentiated. It is assumed that each value in
|
||||||
|
`target` is associated with an example so the gradient clipping would be
|
||||||
|
applied to the vectorized target.
|
||||||
|
sources: a list or nested structure of Tensors or Variables. `target` will
|
||||||
|
be differentiated against elements in `sources`.
|
||||||
|
unconnected_gradients: a value which can either hold 'none' or 'zero' and
|
||||||
|
alters the value which will be returned if the `target` and `sources` are
|
||||||
|
unconnected. The possible values and effects are detailed in
|
||||||
|
'UnconnectedGradients' and it defaults to 'none'.
|
||||||
|
normalize: whether to normalize each gradient.
|
||||||
|
l2_norm_clip: when `normalize` is `True`, every gradient is scaled to
|
||||||
|
`l2_norm_clip` (which can be set to None, understood as 1). When
|
||||||
|
`normalize` is `False`, it performs the regular clipping, i.e. scaling the
|
||||||
|
gradient to `l2_norm_clip` only if the gradient's L2 norm is larger than
|
||||||
|
`l2_norm_clip`. When `l2_norm_clip` is `None`, do nothing.
|
||||||
|
aggregate_method: the method for aggregating the gradients. Currently only
|
||||||
|
supports `sum` and `mean`, default to `mean`.
|
||||||
|
keep_sparse_threshold: when the gradient is a `tf.IndexedSlices`,
|
||||||
|
`keep_sparse_threshold` is used to determine if we should keep it in its
|
||||||
|
sparse representation (when the number of embedding items, i.e. vocabulary
|
||||||
|
size >= `keep_sparse_threshold`) or convert it into a dense tensor (when <
|
||||||
|
`keep_sparse_threshold`). The reason for this parameter is that the
|
||||||
|
current implementation of embedding lookup merges all the indices in a
|
||||||
|
batch, hence the sparse representation has input size the same as the
|
||||||
|
number of indices. When it is larger than the embedding size, it would be
|
||||||
|
more efficient to convert the sparse representation to a dense tensor. So
|
||||||
|
this threshold should be set around the number of indices in a typical
|
||||||
|
batch. When it is -1, always convert the sparse tensor to a dense tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Gradients stored in the same structure as `sources` with a one to one
|
||||||
|
mapping to the variables in `sources`. Each gradients may be a dense
|
||||||
|
tensor or a `tf.IndexedSlices`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: if `tape` is not persistent.
|
||||||
|
ValueError: if aggregate_method is not 'mean' or 'sum'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if tape._tape is None: # pylint: disable=protected-access
|
||||||
|
raise RuntimeError('A non-persistent GradientTape can only be used to '
|
||||||
|
'compute one set of gradients (or jacobians)')
|
||||||
|
|
||||||
|
if aggregate_method not in ['mean', 'sum']:
|
||||||
|
raise ValueError('Only mean and sum methods are supported. But got '
|
||||||
|
f'{aggregate_method}')
|
||||||
|
|
||||||
|
flat_sources = tf.nest.flatten(sources)
|
||||||
|
# Note that we push and pop the tape here and below. This is needed since we
|
||||||
|
# need gradients through the enclosed operations.
|
||||||
|
with tape._ensure_recording(): # pylint: disable=protected-access
|
||||||
|
target = tf.reshape(target, [-1])
|
||||||
|
target_shape = target.shape
|
||||||
|
|
||||||
|
convert_to_dense_indicator = [True for _ in flat_sources]
|
||||||
|
if keep_sparse_threshold >= 0:
|
||||||
|
convert_to_dense_indicator = [
|
||||||
|
s.shape[0] < keep_sparse_threshold for s in flat_sources
|
||||||
|
]
|
||||||
|
|
||||||
|
def _unpack_indexed_slices(x, convert_to_dense):
|
||||||
|
"""Optionally unpacks `tf.IndexedSlices` to dict of three dense tensors."""
|
||||||
|
if convert_to_dense or not isinstance(x, tf.IndexedSlices):
|
||||||
|
# If x is kept as a tf.IndexedSlices, it will be converted to a dense
|
||||||
|
# tensor in pfor.
|
||||||
|
return x
|
||||||
|
return {
|
||||||
|
'indices': x.indices,
|
||||||
|
'values': x.values,
|
||||||
|
'dense_shape': x.dense_shape
|
||||||
|
}
|
||||||
|
|
||||||
|
def loop_fn(i):
|
||||||
|
with tape._ensure_recording(): # pylint: disable=protected-access
|
||||||
|
y = tf.gather(target, i)
|
||||||
|
g = tape.gradient(
|
||||||
|
y, flat_sources, unconnected_gradients=unconnected_gradients)
|
||||||
|
g = tf.nest.map_structure(_unpack_indexed_slices, g,
|
||||||
|
convert_to_dense_indicator)
|
||||||
|
return g
|
||||||
|
|
||||||
|
try:
|
||||||
|
target_size = int(target.shape[0])
|
||||||
|
except TypeError:
|
||||||
|
# When the shape is unavailable, fall back to the tensor op.
|
||||||
|
target_size = tf.shape(target)[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = control_flow_ops.pfor(loop_fn, target_size)
|
||||||
|
except ValueError as err:
|
||||||
|
six.reraise(
|
||||||
|
ValueError,
|
||||||
|
ValueError(
|
||||||
|
str(err) + '\nEncountered an exception while vectorizing the '
|
||||||
|
'jacobian computation. Consider using a non-vectorized version, '
|
||||||
|
'i.e. by computing the gradient for each output sequentially.'),
|
||||||
|
sys.exc_info()[2])
|
||||||
|
|
||||||
|
grads = []
|
||||||
|
for i, out in enumerate(output):
|
||||||
|
if out is not None:
|
||||||
|
# Determines if the output is a unpacked tf.IndexedSlices. Since `sources`
|
||||||
|
# has been flattened, it is only when the output is a dictionary (of three
|
||||||
|
# dense tensors).
|
||||||
|
if not isinstance(out, dict):
|
||||||
|
if tf.executing_eagerly():
|
||||||
|
out.set_shape(target_shape.concatenate(flat_sources[i].shape))
|
||||||
|
grads.append((out, None, None))
|
||||||
|
else:
|
||||||
|
# Remove duplicates at per-example level. This is for both correctness
|
||||||
|
# (when the same index gets gathered more than once in the same example)
|
||||||
|
# and efficiency (for the subsequent clipping). All the examples in
|
||||||
|
# the batch should have the same indices so it suffices to take the
|
||||||
|
# first row.
|
||||||
|
values, indices = _deduplicate_batch_indexed_slices(
|
||||||
|
out['values'], out['indices'][0])
|
||||||
|
# The `dense_shape` of all the examples are the same so we take the
|
||||||
|
# first row.
|
||||||
|
grads.append((values, indices, out['dense_shape'][0]))
|
||||||
|
else:
|
||||||
|
grads.append((None, None, None))
|
||||||
|
|
||||||
|
if normalize or l2_norm_clip is not None:
|
||||||
|
values, indices, dense_shape = zip(*grads)
|
||||||
|
values = _batch_clip_by_global_norm(values, normalize, l2_norm_clip)
|
||||||
|
grads = zip(values, indices, dense_shape)
|
||||||
|
|
||||||
|
new_output = []
|
||||||
|
for values, indices, dense_shape in grads:
|
||||||
|
if values is None:
|
||||||
|
new_output.append(None)
|
||||||
|
continue
|
||||||
|
if aggregate_method == 'sum':
|
||||||
|
values = tf.reduce_sum(values, axis=0)
|
||||||
|
else:
|
||||||
|
values = tf.reduce_mean(values, axis=0)
|
||||||
|
if indices is None:
|
||||||
|
new_output.append(values)
|
||||||
|
else:
|
||||||
|
new_output.append(
|
||||||
|
tf.IndexedSlices(
|
||||||
|
values=values, indices=indices, dense_shape=dense_shape))
|
||||||
|
return tf.nest.pack_sequence_as(sources, new_output)
|
|
@ -0,0 +1,213 @@
|
||||||
|
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Test the correctness and sparseness of clip_and_aggregate_gradients."""
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.optimizers import clip_and_aggregate_gradients as cag
|
||||||
|
|
||||||
|
|
||||||
|
class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
"""Tests clip_and_aggreate_gradients."""
|
||||||
|
|
||||||
|
def _get_loss_and_vars_fn(self, n, keepdims=False):
|
||||||
|
"""Returns the function for creating the loss and variables."""
|
||||||
|
# The "model" here consists of both sparse and dense parameters to make sure
|
||||||
|
# `clip_and_aggregate_gradients` computes the gradients in the correct way
|
||||||
|
# and in the right format. The sparse layer is the embedding layer `emb0`,
|
||||||
|
# from which multiple embeddings are gathered, with indices stored
|
||||||
|
# in `ind0`. And the dense parameters is the variable var1 which is directly
|
||||||
|
# used. The loss is the quadratic loss between the model output and the
|
||||||
|
# data stored in `data0` and `data1`. We also add a dummy variable
|
||||||
|
# `dummy_var` which does not participate in the loss computation to test
|
||||||
|
# the `unconnected` argument.
|
||||||
|
emb0 = tf.keras.layers.Embedding(
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
embeddings_initializer=tf.keras.initializers.Constant(
|
||||||
|
np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])))
|
||||||
|
ind0 = tf.constant([1, 1, 2, 3, 2])
|
||||||
|
data0 = tf.constant([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0],
|
||||||
|
[-2.0, -1.0], [-3.0, -2.0]])
|
||||||
|
|
||||||
|
var1 = tf.Variable([[1.0], [1.0], [2.0], [2.0], [3.0], [3.0]])
|
||||||
|
data1 = tf.constant([[-1.0], [-2.0], [-2.0], [-3.0], [-3.0], [-4.0]])
|
||||||
|
|
||||||
|
dummy_var = tf.Variable(np.array([[1.0]]).astype(np.float64))
|
||||||
|
|
||||||
|
def _loss(val0, val1):
|
||||||
|
return 0.5 * tf.reduce_sum(
|
||||||
|
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||||
|
|
||||||
|
def _loss_and_vars_fn():
|
||||||
|
# We concatenate the embeddings with some constant values to make sure
|
||||||
|
# backprop does only go through those gathered indices.
|
||||||
|
val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
|
||||||
|
loss = tf.reduce_sum(
|
||||||
|
tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
|
||||||
|
keepdims=keepdims,
|
||||||
|
axis=1)
|
||||||
|
return loss, (emb0.embeddings, var1, dummy_var)
|
||||||
|
|
||||||
|
return _loss_and_vars_fn
|
||||||
|
|
||||||
|
def _get_true_grads(self,
|
||||||
|
n,
|
||||||
|
normalize=False,
|
||||||
|
l2_norm_clip=None,
|
||||||
|
agg_method='mean',
|
||||||
|
unconnected='none'):
|
||||||
|
# The per-example gradients (or jacobians) below are computed manually.
|
||||||
|
# With the (half) quadratic loss, it is the difference between the
|
||||||
|
# variable value and the data value.
|
||||||
|
grad0 = np.array([[[0., 0.], [-2., -3.], [0., 0.], [0., 0.]],
|
||||||
|
[[0., 0.], [-4., -5.], [0., 0.], [0., 0.]],
|
||||||
|
[[0., 0.], [0., 0.], [-5., -6.], [0., 0.]],
|
||||||
|
[[0., 0.], [0., 0.], [0., 0.], [4., 3.]],
|
||||||
|
[[0., 0.], [0., 0.], [4., 3.], [0., 0.]],
|
||||||
|
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]]],
|
||||||
|
dtype=np.float32)
|
||||||
|
grad1 = np.array([[[2.], [0.], [0.], [0.], [0.], [0.]],
|
||||||
|
[[0.], [3.], [0.], [0.], [0.], [0.]],
|
||||||
|
[[0.], [0.], [4.], [0.], [0.], [0.]],
|
||||||
|
[[0.], [0.], [0.], [5.], [0.], [0.]],
|
||||||
|
[[0.], [0.], [0.], [0.], [6.], [0.]],
|
||||||
|
[[0.], [0.], [0.], [0.], [0.], [7.]]],
|
||||||
|
dtype=np.float32)
|
||||||
|
grad2 = np.array([[[0.]], [[0.]], [[0.]], [[0.]], [[0.]], [[0.]]],
|
||||||
|
dtype=np.float64)
|
||||||
|
|
||||||
|
grads = [
|
||||||
|
np.sum(np.reshape(g, (n, -1, g.shape[1], g.shape[2])), axis=1)
|
||||||
|
for g in [grad0, grad1, grad2]
|
||||||
|
]
|
||||||
|
|
||||||
|
if normalize or l2_norm_clip is not None:
|
||||||
|
if l2_norm_clip is None:
|
||||||
|
l2_norm_clip = 1.0
|
||||||
|
global_norm = np.sqrt(
|
||||||
|
np.sum([
|
||||||
|
np.sum(np.square(np.reshape(g, (n, -1))), axis=1) for g in grads
|
||||||
|
],
|
||||||
|
axis=0))
|
||||||
|
clip_ratio = l2_norm_clip / np.maximum(global_norm, 1e-8)
|
||||||
|
if not normalize:
|
||||||
|
clip_ratio = np.minimum(1.0, clip_ratio)
|
||||||
|
r = np.reshape(clip_ratio, [n, 1, 1])
|
||||||
|
grads = [g * r for g in grads]
|
||||||
|
|
||||||
|
if agg_method == 'sum':
|
||||||
|
grads = [np.sum(g, axis=0) for g in grads]
|
||||||
|
else:
|
||||||
|
grads = [np.mean(g, axis=0) for g in grads]
|
||||||
|
|
||||||
|
if unconnected == 'none':
|
||||||
|
grads[2] = None
|
||||||
|
return grads
|
||||||
|
|
||||||
|
def _to_dense_array(self, g):
|
||||||
|
if g is None:
|
||||||
|
return None
|
||||||
|
return np.array(tf.convert_to_tensor(g))
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(6, False, None, 'mean', -1, 'none'),
|
||||||
|
(6, True, None, 'sum', 1, 'none'),
|
||||||
|
(2, False, None, 'sum', 3, 'none'),
|
||||||
|
(2, True, 100.0, 'mean', 1, 'zero'),
|
||||||
|
(3, False, 1.0, 'sum', 2, 'zero'),
|
||||||
|
(1, True, 0.5, 'mean', 3, 'none'),
|
||||||
|
)
|
||||||
|
def testCorrect(self, n, normalize, l2_norm_clip, agg_method,
|
||||||
|
keep_sparse_threshold, unconnected):
|
||||||
|
"""Tests the correctness of the computation."""
|
||||||
|
loss_and_vars_fn = self._get_loss_and_vars_fn(n)
|
||||||
|
true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
|
||||||
|
unconnected)
|
||||||
|
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
loss, test_vars = loss_and_vars_fn()
|
||||||
|
results = cag.clip_and_aggregate_gradients(
|
||||||
|
tape,
|
||||||
|
loss,
|
||||||
|
test_vars,
|
||||||
|
normalize=normalize,
|
||||||
|
l2_norm_clip=l2_norm_clip,
|
||||||
|
aggregate_method=agg_method,
|
||||||
|
unconnected_gradients=unconnected,
|
||||||
|
keep_sparse_threshold=keep_sparse_threshold)
|
||||||
|
for r, t in zip(results, true_grads):
|
||||||
|
if t is None:
|
||||||
|
self.assertIsNone(r)
|
||||||
|
else:
|
||||||
|
r = self._to_dense_array(r)
|
||||||
|
self.assertAllCloseAccordingToType(r, t)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(6, True),
|
||||||
|
(6, False),
|
||||||
|
(1, True),
|
||||||
|
(1, False),
|
||||||
|
)
|
||||||
|
def testTargetShape(self, n, keepdims):
|
||||||
|
"""Tests target gets vectorized regardless of their original shape."""
|
||||||
|
loss_and_vars_fn = self._get_loss_and_vars_fn(n, keepdims)
|
||||||
|
true_grads = self._get_true_grads(n)
|
||||||
|
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
loss, test_vars = loss_and_vars_fn()
|
||||||
|
results = cag.clip_and_aggregate_gradients(tape, loss, test_vars)
|
||||||
|
for r, t in zip(results, true_grads):
|
||||||
|
if t is None:
|
||||||
|
self.assertIsNone(r)
|
||||||
|
else:
|
||||||
|
r = self._to_dense_array(r)
|
||||||
|
self.assertAllCloseAccordingToType(r, t)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(-1),
|
||||||
|
(0),
|
||||||
|
(4),
|
||||||
|
(5),
|
||||||
|
)
|
||||||
|
def testSparse(self, keep_sparse_threshold):
|
||||||
|
"""Tests the outcome is in the desired (dense or sparse) tensor form."""
|
||||||
|
loss_and_vars_fn = self._get_loss_and_vars_fn(3)
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
loss, test_vars = loss_and_vars_fn()
|
||||||
|
results = cag.clip_and_aggregate_gradients(
|
||||||
|
tape,
|
||||||
|
loss,
|
||||||
|
test_vars,
|
||||||
|
normalize=False,
|
||||||
|
l2_norm_clip=1.0,
|
||||||
|
aggregate_method='mean',
|
||||||
|
unconnected_gradients='zero',
|
||||||
|
keep_sparse_threshold=keep_sparse_threshold)
|
||||||
|
grads0, grads1, grads2 = results
|
||||||
|
# emb0 has 4 items so grads0 should be in the sparse, i.e.
|
||||||
|
# `tf.IndexedSlices`, form iff `keep_sparse_threshold` is in [0, 4].
|
||||||
|
if keep_sparse_threshold >= 0 and keep_sparse_threshold <= 4:
|
||||||
|
self.assertIsInstance(grads0, tf.IndexedSlices)
|
||||||
|
self.assertLen(grads0.indices, 3)
|
||||||
|
else:
|
||||||
|
self.assertIsInstance(grads0, tf.Tensor)
|
||||||
|
# grads1 and grads2 should always be in the dense, i.e. `tf.Tensor`, form.
|
||||||
|
self.assertIsInstance(grads1, tf.Tensor)
|
||||||
|
self.assertIsInstance(grads2, tf.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
Loading…
Reference in a new issue