Clip (per-example) and aggregate gradients.

PiperOrigin-RevId: 480761907
This commit is contained in:
A. Unique TensorFlower 2022-10-12 17:42:56 -07:00
parent 71837fbeec
commit c25cb4a41b
3 changed files with 480 additions and 0 deletions

View file

@ -9,6 +9,18 @@ py_library(
srcs = ["__init__.py"],
)
py_library(
name = "clip_and_aggregate_gradients",
srcs = [
"clip_and_aggregate_gradients.py",
],
srcs_version = "PY3",
deps = [
"//third_party/py/six",
"//third_party/tensorflow/python/ops/parallel_for:control_flow_ops",
],
)
py_library(
name = "dp_optimizer",
srcs = [
@ -63,6 +75,14 @@ py_library(
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
)
py_test(
name = "clip_and_aggregate_gradients_test",
srcs = ["clip_and_aggregate_gradients_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":clip_and_aggregate_gradients"],
)
py_test(
name = "dp_optimizer_test",
timeout = "long",

View file

@ -0,0 +1,247 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Per example gradients clipping and aggregation for sparse gradients.
Modified from tape.jacobian to support sparse gradients.
"""
import sys
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
import six
import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import
GradientTensor = Union[tf.Tensor, tf.IndexedSlices]
T = TypeVar('T')
Nested = Union[T, Tuple[Any, ...], List[Any], Dict[str, Any]]
def _deduplicate_batch_indexed_slices(
batched_values: tf.Tensor,
indices: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Removes duplication of indexed slices by summing them."""
perm = tf.concat([
tf.constant([1, 0], dtype=tf.int32),
tf.range(tf.rank(batched_values))[2:]
],
axis=0)
batched_values = tf.transpose(batched_values, perm=perm)
unique_indices, new_pos = tf.unique(indices)
summed_values = tf.math.unsorted_segment_sum(batched_values, new_pos,
tf.shape(unique_indices)[0])
return tf.transpose(summed_values, perm=perm), unique_indices
def _batch_global_norm(vals: List[tf.Tensor]) -> tf.Tensor:
"""Computes the global norm for each row in the batch."""
def _norm_squared(v):
return tf.cast(
tf.reduce_sum(
tf.reshape(tf.square(v), tf.stack([tf.shape(v)[0], -1])), axis=1),
tf.float32)
return tf.sqrt(tf.add_n([_norm_squared(v) for v in vals if v is not None]))
def _batch_clip_by_global_norm(
vals: List[tf.Tensor], normalize: bool,
l2_norm_clip: Optional[float]) -> List[tf.Tensor]:
"""Batch clips by global norm with normalize option."""
batch_global_norm = _batch_global_norm(vals)
if l2_norm_clip is None:
l2_norm_clip = 1.0
clip_ratio = l2_norm_clip / tf.maximum(batch_global_norm, 1e-8)
if not normalize:
clip_ratio = tf.minimum(1.0, clip_ratio)
def _expand_dims(e, v):
new_shape = tf.concat(
[tf.shape(v)[0:1],
tf.ones_like(tf.shape(v), dtype=tf.int32)[:-1]],
axis=0)
return tf.reshape(e, new_shape)
return [
v *
_expand_dims(tf.cast(clip_ratio, v.dtype), v) if v is not None else None
for v in vals
]
def clip_and_aggregate_gradients(
tape: tf.GradientTape,
target: tf.Tensor,
sources: Nested[tf.Tensor],
unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients
.NONE,
normalize: bool = False,
l2_norm_clip: Optional[float] = None,
aggregate_method: str = 'mean',
keep_sparse_threshold: int = 10000) -> Nested[GradientTensor]:
"""Clips (per-example) and aggregates gradients.
This procedure computes the Jacobian with respect to a vectorized loss,
i.e. the `target` argument, clips the gradient with repsect to each
individual output, and sums the clipped gradients. This is correct as
per-example gradient if there is a one to one mapping from the input example
to the output loss.
Args:
tape: a persistent tape.
target: Tensor to be differentiated. It is assumed that each value in
`target` is associated with an example so the gradient clipping would be
applied to the vectorized target.
sources: a list or nested structure of Tensors or Variables. `target` will
be differentiated against elements in `sources`.
unconnected_gradients: a value which can either hold 'none' or 'zero' and
alters the value which will be returned if the `target` and `sources` are
unconnected. The possible values and effects are detailed in
'UnconnectedGradients' and it defaults to 'none'.
normalize: whether to normalize each gradient.
l2_norm_clip: when `normalize` is `True`, every gradient is scaled to
`l2_norm_clip` (which can be set to None, understood as 1). When
`normalize` is `False`, it performs the regular clipping, i.e. scaling the
gradient to `l2_norm_clip` only if the gradient's L2 norm is larger than
`l2_norm_clip`. When `l2_norm_clip` is `None`, do nothing.
aggregate_method: the method for aggregating the gradients. Currently only
supports `sum` and `mean`, default to `mean`.
keep_sparse_threshold: when the gradient is a `tf.IndexedSlices`,
`keep_sparse_threshold` is used to determine if we should keep it in its
sparse representation (when the number of embedding items, i.e. vocabulary
size >= `keep_sparse_threshold`) or convert it into a dense tensor (when <
`keep_sparse_threshold`). The reason for this parameter is that the
current implementation of embedding lookup merges all the indices in a
batch, hence the sparse representation has input size the same as the
number of indices. When it is larger than the embedding size, it would be
more efficient to convert the sparse representation to a dense tensor. So
this threshold should be set around the number of indices in a typical
batch. When it is -1, always convert the sparse tensor to a dense tensor.
Returns:
Gradients stored in the same structure as `sources` with a one to one
mapping to the variables in `sources`. Each gradients may be a dense
tensor or a `tf.IndexedSlices`.
Raises:
RuntimeError: if `tape` is not persistent.
ValueError: if aggregate_method is not 'mean' or 'sum'.
"""
if tape._tape is None: # pylint: disable=protected-access
raise RuntimeError('A non-persistent GradientTape can only be used to '
'compute one set of gradients (or jacobians)')
if aggregate_method not in ['mean', 'sum']:
raise ValueError('Only mean and sum methods are supported. But got '
f'{aggregate_method}')
flat_sources = tf.nest.flatten(sources)
# Note that we push and pop the tape here and below. This is needed since we
# need gradients through the enclosed operations.
with tape._ensure_recording(): # pylint: disable=protected-access
target = tf.reshape(target, [-1])
target_shape = target.shape
convert_to_dense_indicator = [True for _ in flat_sources]
if keep_sparse_threshold >= 0:
convert_to_dense_indicator = [
s.shape[0] < keep_sparse_threshold for s in flat_sources
]
def _unpack_indexed_slices(x, convert_to_dense):
"""Optionally unpacks `tf.IndexedSlices` to dict of three dense tensors."""
if convert_to_dense or not isinstance(x, tf.IndexedSlices):
# If x is kept as a tf.IndexedSlices, it will be converted to a dense
# tensor in pfor.
return x
return {
'indices': x.indices,
'values': x.values,
'dense_shape': x.dense_shape
}
def loop_fn(i):
with tape._ensure_recording(): # pylint: disable=protected-access
y = tf.gather(target, i)
g = tape.gradient(
y, flat_sources, unconnected_gradients=unconnected_gradients)
g = tf.nest.map_structure(_unpack_indexed_slices, g,
convert_to_dense_indicator)
return g
try:
target_size = int(target.shape[0])
except TypeError:
# When the shape is unavailable, fall back to the tensor op.
target_size = tf.shape(target)[0]
try:
output = control_flow_ops.pfor(loop_fn, target_size)
except ValueError as err:
six.reraise(
ValueError,
ValueError(
str(err) + '\nEncountered an exception while vectorizing the '
'jacobian computation. Consider using a non-vectorized version, '
'i.e. by computing the gradient for each output sequentially.'),
sys.exc_info()[2])
grads = []
for i, out in enumerate(output):
if out is not None:
# Determines if the output is a unpacked tf.IndexedSlices. Since `sources`
# has been flattened, it is only when the output is a dictionary (of three
# dense tensors).
if not isinstance(out, dict):
if tf.executing_eagerly():
out.set_shape(target_shape.concatenate(flat_sources[i].shape))
grads.append((out, None, None))
else:
# Remove duplicates at per-example level. This is for both correctness
# (when the same index gets gathered more than once in the same example)
# and efficiency (for the subsequent clipping). All the examples in
# the batch should have the same indices so it suffices to take the
# first row.
values, indices = _deduplicate_batch_indexed_slices(
out['values'], out['indices'][0])
# The `dense_shape` of all the examples are the same so we take the
# first row.
grads.append((values, indices, out['dense_shape'][0]))
else:
grads.append((None, None, None))
if normalize or l2_norm_clip is not None:
values, indices, dense_shape = zip(*grads)
values = _batch_clip_by_global_norm(values, normalize, l2_norm_clip)
grads = zip(values, indices, dense_shape)
new_output = []
for values, indices, dense_shape in grads:
if values is None:
new_output.append(None)
continue
if aggregate_method == 'sum':
values = tf.reduce_sum(values, axis=0)
else:
values = tf.reduce_mean(values, axis=0)
if indices is None:
new_output.append(values)
else:
new_output.append(
tf.IndexedSlices(
values=values, indices=indices, dense_shape=dense_shape))
return tf.nest.pack_sequence_as(sources, new_output)

View file

@ -0,0 +1,213 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the correctness and sparseness of clip_and_aggregate_gradients."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import clip_and_aggregate_gradients as cag
class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
"""Tests clip_and_aggreate_gradients."""
def _get_loss_and_vars_fn(self, n, keepdims=False):
"""Returns the function for creating the loss and variables."""
# The "model" here consists of both sparse and dense parameters to make sure
# `clip_and_aggregate_gradients` computes the gradients in the correct way
# and in the right format. The sparse layer is the embedding layer `emb0`,
# from which multiple embeddings are gathered, with indices stored
# in `ind0`. And the dense parameters is the variable var1 which is directly
# used. The loss is the quadratic loss between the model output and the
# data stored in `data0` and `data1`. We also add a dummy variable
# `dummy_var` which does not participate in the loss computation to test
# the `unconnected` argument.
emb0 = tf.keras.layers.Embedding(
4,
2,
embeddings_initializer=tf.keras.initializers.Constant(
np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])))
ind0 = tf.constant([1, 1, 2, 3, 2])
data0 = tf.constant([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0],
[-2.0, -1.0], [-3.0, -2.0]])
var1 = tf.Variable([[1.0], [1.0], [2.0], [2.0], [3.0], [3.0]])
data1 = tf.constant([[-1.0], [-2.0], [-2.0], [-3.0], [-3.0], [-4.0]])
dummy_var = tf.Variable(np.array([[1.0]]).astype(np.float64))
def _loss(val0, val1):
return 0.5 * tf.reduce_sum(
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
def _loss_and_vars_fn():
# We concatenate the embeddings with some constant values to make sure
# backprop does only go through those gathered indices.
val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
loss = tf.reduce_sum(
tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
keepdims=keepdims,
axis=1)
return loss, (emb0.embeddings, var1, dummy_var)
return _loss_and_vars_fn
def _get_true_grads(self,
n,
normalize=False,
l2_norm_clip=None,
agg_method='mean',
unconnected='none'):
# The per-example gradients (or jacobians) below are computed manually.
# With the (half) quadratic loss, it is the difference between the
# variable value and the data value.
grad0 = np.array([[[0., 0.], [-2., -3.], [0., 0.], [0., 0.]],
[[0., 0.], [-4., -5.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [-5., -6.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [4., 3.]],
[[0., 0.], [0., 0.], [4., 3.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]]],
dtype=np.float32)
grad1 = np.array([[[2.], [0.], [0.], [0.], [0.], [0.]],
[[0.], [3.], [0.], [0.], [0.], [0.]],
[[0.], [0.], [4.], [0.], [0.], [0.]],
[[0.], [0.], [0.], [5.], [0.], [0.]],
[[0.], [0.], [0.], [0.], [6.], [0.]],
[[0.], [0.], [0.], [0.], [0.], [7.]]],
dtype=np.float32)
grad2 = np.array([[[0.]], [[0.]], [[0.]], [[0.]], [[0.]], [[0.]]],
dtype=np.float64)
grads = [
np.sum(np.reshape(g, (n, -1, g.shape[1], g.shape[2])), axis=1)
for g in [grad0, grad1, grad2]
]
if normalize or l2_norm_clip is not None:
if l2_norm_clip is None:
l2_norm_clip = 1.0
global_norm = np.sqrt(
np.sum([
np.sum(np.square(np.reshape(g, (n, -1))), axis=1) for g in grads
],
axis=0))
clip_ratio = l2_norm_clip / np.maximum(global_norm, 1e-8)
if not normalize:
clip_ratio = np.minimum(1.0, clip_ratio)
r = np.reshape(clip_ratio, [n, 1, 1])
grads = [g * r for g in grads]
if agg_method == 'sum':
grads = [np.sum(g, axis=0) for g in grads]
else:
grads = [np.mean(g, axis=0) for g in grads]
if unconnected == 'none':
grads[2] = None
return grads
def _to_dense_array(self, g):
if g is None:
return None
return np.array(tf.convert_to_tensor(g))
@parameterized.parameters(
(6, False, None, 'mean', -1, 'none'),
(6, True, None, 'sum', 1, 'none'),
(2, False, None, 'sum', 3, 'none'),
(2, True, 100.0, 'mean', 1, 'zero'),
(3, False, 1.0, 'sum', 2, 'zero'),
(1, True, 0.5, 'mean', 3, 'none'),
)
def testCorrect(self, n, normalize, l2_norm_clip, agg_method,
keep_sparse_threshold, unconnected):
"""Tests the correctness of the computation."""
loss_and_vars_fn = self._get_loss_and_vars_fn(n)
true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
unconnected)
with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn()
results = cag.clip_and_aggregate_gradients(
tape,
loss,
test_vars,
normalize=normalize,
l2_norm_clip=l2_norm_clip,
aggregate_method=agg_method,
unconnected_gradients=unconnected,
keep_sparse_threshold=keep_sparse_threshold)
for r, t in zip(results, true_grads):
if t is None:
self.assertIsNone(r)
else:
r = self._to_dense_array(r)
self.assertAllCloseAccordingToType(r, t)
@parameterized.parameters(
(6, True),
(6, False),
(1, True),
(1, False),
)
def testTargetShape(self, n, keepdims):
"""Tests target gets vectorized regardless of their original shape."""
loss_and_vars_fn = self._get_loss_and_vars_fn(n, keepdims)
true_grads = self._get_true_grads(n)
with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn()
results = cag.clip_and_aggregate_gradients(tape, loss, test_vars)
for r, t in zip(results, true_grads):
if t is None:
self.assertIsNone(r)
else:
r = self._to_dense_array(r)
self.assertAllCloseAccordingToType(r, t)
@parameterized.parameters(
(-1),
(0),
(4),
(5),
)
def testSparse(self, keep_sparse_threshold):
"""Tests the outcome is in the desired (dense or sparse) tensor form."""
loss_and_vars_fn = self._get_loss_and_vars_fn(3)
with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn()
results = cag.clip_and_aggregate_gradients(
tape,
loss,
test_vars,
normalize=False,
l2_norm_clip=1.0,
aggregate_method='mean',
unconnected_gradients='zero',
keep_sparse_threshold=keep_sparse_threshold)
grads0, grads1, grads2 = results
# emb0 has 4 items so grads0 should be in the sparse, i.e.
# `tf.IndexedSlices`, form iff `keep_sparse_threshold` is in [0, 4].
if keep_sparse_threshold >= 0 and keep_sparse_threshold <= 4:
self.assertIsInstance(grads0, tf.IndexedSlices)
self.assertLen(grads0.indices, 3)
else:
self.assertIsInstance(grads0, tf.Tensor)
# grads1 and grads2 should always be in the dense, i.e. `tf.Tensor`, form.
self.assertIsInstance(grads1, tf.Tensor)
self.assertIsInstance(grads2, tf.Tensor)
if __name__ == '__main__':
tf.test.main()