forked from 626_privacy/tensorflow_privacy
Clip (per-example) and aggregate gradients.
PiperOrigin-RevId: 480761907
This commit is contained in:
parent
71837fbeec
commit
c25cb4a41b
3 changed files with 480 additions and 0 deletions
|
@ -9,6 +9,18 @@ py_library(
|
|||
srcs = ["__init__.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "clip_and_aggregate_gradients",
|
||||
srcs = [
|
||||
"clip_and_aggregate_gradients.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//third_party/py/six",
|
||||
"//third_party/tensorflow/python/ops/parallel_for:control_flow_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dp_optimizer",
|
||||
srcs = [
|
||||
|
@ -63,6 +75,14 @@ py_library(
|
|||
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "clip_and_aggregate_gradients_test",
|
||||
srcs = ["clip_and_aggregate_gradients_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":clip_and_aggregate_gradients"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dp_optimizer_test",
|
||||
timeout = "long",
|
||||
|
|
|
@ -0,0 +1,247 @@
|
|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Per example gradients clipping and aggregation for sparse gradients.
|
||||
|
||||
Modified from tape.jacobian to support sparse gradients.
|
||||
"""
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import
|
||||
|
||||
GradientTensor = Union[tf.Tensor, tf.IndexedSlices]
|
||||
T = TypeVar('T')
|
||||
Nested = Union[T, Tuple[Any, ...], List[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
def _deduplicate_batch_indexed_slices(
|
||||
batched_values: tf.Tensor,
|
||||
indices: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""Removes duplication of indexed slices by summing them."""
|
||||
perm = tf.concat([
|
||||
tf.constant([1, 0], dtype=tf.int32),
|
||||
tf.range(tf.rank(batched_values))[2:]
|
||||
],
|
||||
axis=0)
|
||||
batched_values = tf.transpose(batched_values, perm=perm)
|
||||
unique_indices, new_pos = tf.unique(indices)
|
||||
summed_values = tf.math.unsorted_segment_sum(batched_values, new_pos,
|
||||
tf.shape(unique_indices)[0])
|
||||
return tf.transpose(summed_values, perm=perm), unique_indices
|
||||
|
||||
|
||||
def _batch_global_norm(vals: List[tf.Tensor]) -> tf.Tensor:
|
||||
"""Computes the global norm for each row in the batch."""
|
||||
|
||||
def _norm_squared(v):
|
||||
return tf.cast(
|
||||
tf.reduce_sum(
|
||||
tf.reshape(tf.square(v), tf.stack([tf.shape(v)[0], -1])), axis=1),
|
||||
tf.float32)
|
||||
|
||||
return tf.sqrt(tf.add_n([_norm_squared(v) for v in vals if v is not None]))
|
||||
|
||||
|
||||
def _batch_clip_by_global_norm(
|
||||
vals: List[tf.Tensor], normalize: bool,
|
||||
l2_norm_clip: Optional[float]) -> List[tf.Tensor]:
|
||||
"""Batch clips by global norm with normalize option."""
|
||||
batch_global_norm = _batch_global_norm(vals)
|
||||
if l2_norm_clip is None:
|
||||
l2_norm_clip = 1.0
|
||||
clip_ratio = l2_norm_clip / tf.maximum(batch_global_norm, 1e-8)
|
||||
if not normalize:
|
||||
clip_ratio = tf.minimum(1.0, clip_ratio)
|
||||
|
||||
def _expand_dims(e, v):
|
||||
new_shape = tf.concat(
|
||||
[tf.shape(v)[0:1],
|
||||
tf.ones_like(tf.shape(v), dtype=tf.int32)[:-1]],
|
||||
axis=0)
|
||||
return tf.reshape(e, new_shape)
|
||||
|
||||
return [
|
||||
v *
|
||||
_expand_dims(tf.cast(clip_ratio, v.dtype), v) if v is not None else None
|
||||
for v in vals
|
||||
]
|
||||
|
||||
|
||||
def clip_and_aggregate_gradients(
|
||||
tape: tf.GradientTape,
|
||||
target: tf.Tensor,
|
||||
sources: Nested[tf.Tensor],
|
||||
unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients
|
||||
.NONE,
|
||||
normalize: bool = False,
|
||||
l2_norm_clip: Optional[float] = None,
|
||||
aggregate_method: str = 'mean',
|
||||
keep_sparse_threshold: int = 10000) -> Nested[GradientTensor]:
|
||||
"""Clips (per-example) and aggregates gradients.
|
||||
|
||||
This procedure computes the Jacobian with respect to a vectorized loss,
|
||||
i.e. the `target` argument, clips the gradient with repsect to each
|
||||
individual output, and sums the clipped gradients. This is correct as
|
||||
per-example gradient if there is a one to one mapping from the input example
|
||||
to the output loss.
|
||||
|
||||
Args:
|
||||
tape: a persistent tape.
|
||||
target: Tensor to be differentiated. It is assumed that each value in
|
||||
`target` is associated with an example so the gradient clipping would be
|
||||
applied to the vectorized target.
|
||||
sources: a list or nested structure of Tensors or Variables. `target` will
|
||||
be differentiated against elements in `sources`.
|
||||
unconnected_gradients: a value which can either hold 'none' or 'zero' and
|
||||
alters the value which will be returned if the `target` and `sources` are
|
||||
unconnected. The possible values and effects are detailed in
|
||||
'UnconnectedGradients' and it defaults to 'none'.
|
||||
normalize: whether to normalize each gradient.
|
||||
l2_norm_clip: when `normalize` is `True`, every gradient is scaled to
|
||||
`l2_norm_clip` (which can be set to None, understood as 1). When
|
||||
`normalize` is `False`, it performs the regular clipping, i.e. scaling the
|
||||
gradient to `l2_norm_clip` only if the gradient's L2 norm is larger than
|
||||
`l2_norm_clip`. When `l2_norm_clip` is `None`, do nothing.
|
||||
aggregate_method: the method for aggregating the gradients. Currently only
|
||||
supports `sum` and `mean`, default to `mean`.
|
||||
keep_sparse_threshold: when the gradient is a `tf.IndexedSlices`,
|
||||
`keep_sparse_threshold` is used to determine if we should keep it in its
|
||||
sparse representation (when the number of embedding items, i.e. vocabulary
|
||||
size >= `keep_sparse_threshold`) or convert it into a dense tensor (when <
|
||||
`keep_sparse_threshold`). The reason for this parameter is that the
|
||||
current implementation of embedding lookup merges all the indices in a
|
||||
batch, hence the sparse representation has input size the same as the
|
||||
number of indices. When it is larger than the embedding size, it would be
|
||||
more efficient to convert the sparse representation to a dense tensor. So
|
||||
this threshold should be set around the number of indices in a typical
|
||||
batch. When it is -1, always convert the sparse tensor to a dense tensor.
|
||||
|
||||
Returns:
|
||||
Gradients stored in the same structure as `sources` with a one to one
|
||||
mapping to the variables in `sources`. Each gradients may be a dense
|
||||
tensor or a `tf.IndexedSlices`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if `tape` is not persistent.
|
||||
ValueError: if aggregate_method is not 'mean' or 'sum'.
|
||||
"""
|
||||
|
||||
if tape._tape is None: # pylint: disable=protected-access
|
||||
raise RuntimeError('A non-persistent GradientTape can only be used to '
|
||||
'compute one set of gradients (or jacobians)')
|
||||
|
||||
if aggregate_method not in ['mean', 'sum']:
|
||||
raise ValueError('Only mean and sum methods are supported. But got '
|
||||
f'{aggregate_method}')
|
||||
|
||||
flat_sources = tf.nest.flatten(sources)
|
||||
# Note that we push and pop the tape here and below. This is needed since we
|
||||
# need gradients through the enclosed operations.
|
||||
with tape._ensure_recording(): # pylint: disable=protected-access
|
||||
target = tf.reshape(target, [-1])
|
||||
target_shape = target.shape
|
||||
|
||||
convert_to_dense_indicator = [True for _ in flat_sources]
|
||||
if keep_sparse_threshold >= 0:
|
||||
convert_to_dense_indicator = [
|
||||
s.shape[0] < keep_sparse_threshold for s in flat_sources
|
||||
]
|
||||
|
||||
def _unpack_indexed_slices(x, convert_to_dense):
|
||||
"""Optionally unpacks `tf.IndexedSlices` to dict of three dense tensors."""
|
||||
if convert_to_dense or not isinstance(x, tf.IndexedSlices):
|
||||
# If x is kept as a tf.IndexedSlices, it will be converted to a dense
|
||||
# tensor in pfor.
|
||||
return x
|
||||
return {
|
||||
'indices': x.indices,
|
||||
'values': x.values,
|
||||
'dense_shape': x.dense_shape
|
||||
}
|
||||
|
||||
def loop_fn(i):
|
||||
with tape._ensure_recording(): # pylint: disable=protected-access
|
||||
y = tf.gather(target, i)
|
||||
g = tape.gradient(
|
||||
y, flat_sources, unconnected_gradients=unconnected_gradients)
|
||||
g = tf.nest.map_structure(_unpack_indexed_slices, g,
|
||||
convert_to_dense_indicator)
|
||||
return g
|
||||
|
||||
try:
|
||||
target_size = int(target.shape[0])
|
||||
except TypeError:
|
||||
# When the shape is unavailable, fall back to the tensor op.
|
||||
target_size = tf.shape(target)[0]
|
||||
|
||||
try:
|
||||
output = control_flow_ops.pfor(loop_fn, target_size)
|
||||
except ValueError as err:
|
||||
six.reraise(
|
||||
ValueError,
|
||||
ValueError(
|
||||
str(err) + '\nEncountered an exception while vectorizing the '
|
||||
'jacobian computation. Consider using a non-vectorized version, '
|
||||
'i.e. by computing the gradient for each output sequentially.'),
|
||||
sys.exc_info()[2])
|
||||
|
||||
grads = []
|
||||
for i, out in enumerate(output):
|
||||
if out is not None:
|
||||
# Determines if the output is a unpacked tf.IndexedSlices. Since `sources`
|
||||
# has been flattened, it is only when the output is a dictionary (of three
|
||||
# dense tensors).
|
||||
if not isinstance(out, dict):
|
||||
if tf.executing_eagerly():
|
||||
out.set_shape(target_shape.concatenate(flat_sources[i].shape))
|
||||
grads.append((out, None, None))
|
||||
else:
|
||||
# Remove duplicates at per-example level. This is for both correctness
|
||||
# (when the same index gets gathered more than once in the same example)
|
||||
# and efficiency (for the subsequent clipping). All the examples in
|
||||
# the batch should have the same indices so it suffices to take the
|
||||
# first row.
|
||||
values, indices = _deduplicate_batch_indexed_slices(
|
||||
out['values'], out['indices'][0])
|
||||
# The `dense_shape` of all the examples are the same so we take the
|
||||
# first row.
|
||||
grads.append((values, indices, out['dense_shape'][0]))
|
||||
else:
|
||||
grads.append((None, None, None))
|
||||
|
||||
if normalize or l2_norm_clip is not None:
|
||||
values, indices, dense_shape = zip(*grads)
|
||||
values = _batch_clip_by_global_norm(values, normalize, l2_norm_clip)
|
||||
grads = zip(values, indices, dense_shape)
|
||||
|
||||
new_output = []
|
||||
for values, indices, dense_shape in grads:
|
||||
if values is None:
|
||||
new_output.append(None)
|
||||
continue
|
||||
if aggregate_method == 'sum':
|
||||
values = tf.reduce_sum(values, axis=0)
|
||||
else:
|
||||
values = tf.reduce_mean(values, axis=0)
|
||||
if indices is None:
|
||||
new_output.append(values)
|
||||
else:
|
||||
new_output.append(
|
||||
tf.IndexedSlices(
|
||||
values=values, indices=indices, dense_shape=dense_shape))
|
||||
return tf.nest.pack_sequence_as(sources, new_output)
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test the correctness and sparseness of clip_and_aggregate_gradients."""
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.optimizers import clip_and_aggregate_gradients as cag
|
||||
|
||||
|
||||
class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
"""Tests clip_and_aggreate_gradients."""
|
||||
|
||||
def _get_loss_and_vars_fn(self, n, keepdims=False):
|
||||
"""Returns the function for creating the loss and variables."""
|
||||
# The "model" here consists of both sparse and dense parameters to make sure
|
||||
# `clip_and_aggregate_gradients` computes the gradients in the correct way
|
||||
# and in the right format. The sparse layer is the embedding layer `emb0`,
|
||||
# from which multiple embeddings are gathered, with indices stored
|
||||
# in `ind0`. And the dense parameters is the variable var1 which is directly
|
||||
# used. The loss is the quadratic loss between the model output and the
|
||||
# data stored in `data0` and `data1`. We also add a dummy variable
|
||||
# `dummy_var` which does not participate in the loss computation to test
|
||||
# the `unconnected` argument.
|
||||
emb0 = tf.keras.layers.Embedding(
|
||||
4,
|
||||
2,
|
||||
embeddings_initializer=tf.keras.initializers.Constant(
|
||||
np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])))
|
||||
ind0 = tf.constant([1, 1, 2, 3, 2])
|
||||
data0 = tf.constant([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0],
|
||||
[-2.0, -1.0], [-3.0, -2.0]])
|
||||
|
||||
var1 = tf.Variable([[1.0], [1.0], [2.0], [2.0], [3.0], [3.0]])
|
||||
data1 = tf.constant([[-1.0], [-2.0], [-2.0], [-3.0], [-3.0], [-4.0]])
|
||||
|
||||
dummy_var = tf.Variable(np.array([[1.0]]).astype(np.float64))
|
||||
|
||||
def _loss(val0, val1):
|
||||
return 0.5 * tf.reduce_sum(
|
||||
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
def _loss_and_vars_fn():
|
||||
# We concatenate the embeddings with some constant values to make sure
|
||||
# backprop does only go through those gathered indices.
|
||||
val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
|
||||
loss = tf.reduce_sum(
|
||||
tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
|
||||
keepdims=keepdims,
|
||||
axis=1)
|
||||
return loss, (emb0.embeddings, var1, dummy_var)
|
||||
|
||||
return _loss_and_vars_fn
|
||||
|
||||
def _get_true_grads(self,
|
||||
n,
|
||||
normalize=False,
|
||||
l2_norm_clip=None,
|
||||
agg_method='mean',
|
||||
unconnected='none'):
|
||||
# The per-example gradients (or jacobians) below are computed manually.
|
||||
# With the (half) quadratic loss, it is the difference between the
|
||||
# variable value and the data value.
|
||||
grad0 = np.array([[[0., 0.], [-2., -3.], [0., 0.], [0., 0.]],
|
||||
[[0., 0.], [-4., -5.], [0., 0.], [0., 0.]],
|
||||
[[0., 0.], [0., 0.], [-5., -6.], [0., 0.]],
|
||||
[[0., 0.], [0., 0.], [0., 0.], [4., 3.]],
|
||||
[[0., 0.], [0., 0.], [4., 3.], [0., 0.]],
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]]],
|
||||
dtype=np.float32)
|
||||
grad1 = np.array([[[2.], [0.], [0.], [0.], [0.], [0.]],
|
||||
[[0.], [3.], [0.], [0.], [0.], [0.]],
|
||||
[[0.], [0.], [4.], [0.], [0.], [0.]],
|
||||
[[0.], [0.], [0.], [5.], [0.], [0.]],
|
||||
[[0.], [0.], [0.], [0.], [6.], [0.]],
|
||||
[[0.], [0.], [0.], [0.], [0.], [7.]]],
|
||||
dtype=np.float32)
|
||||
grad2 = np.array([[[0.]], [[0.]], [[0.]], [[0.]], [[0.]], [[0.]]],
|
||||
dtype=np.float64)
|
||||
|
||||
grads = [
|
||||
np.sum(np.reshape(g, (n, -1, g.shape[1], g.shape[2])), axis=1)
|
||||
for g in [grad0, grad1, grad2]
|
||||
]
|
||||
|
||||
if normalize or l2_norm_clip is not None:
|
||||
if l2_norm_clip is None:
|
||||
l2_norm_clip = 1.0
|
||||
global_norm = np.sqrt(
|
||||
np.sum([
|
||||
np.sum(np.square(np.reshape(g, (n, -1))), axis=1) for g in grads
|
||||
],
|
||||
axis=0))
|
||||
clip_ratio = l2_norm_clip / np.maximum(global_norm, 1e-8)
|
||||
if not normalize:
|
||||
clip_ratio = np.minimum(1.0, clip_ratio)
|
||||
r = np.reshape(clip_ratio, [n, 1, 1])
|
||||
grads = [g * r for g in grads]
|
||||
|
||||
if agg_method == 'sum':
|
||||
grads = [np.sum(g, axis=0) for g in grads]
|
||||
else:
|
||||
grads = [np.mean(g, axis=0) for g in grads]
|
||||
|
||||
if unconnected == 'none':
|
||||
grads[2] = None
|
||||
return grads
|
||||
|
||||
def _to_dense_array(self, g):
|
||||
if g is None:
|
||||
return None
|
||||
return np.array(tf.convert_to_tensor(g))
|
||||
|
||||
@parameterized.parameters(
|
||||
(6, False, None, 'mean', -1, 'none'),
|
||||
(6, True, None, 'sum', 1, 'none'),
|
||||
(2, False, None, 'sum', 3, 'none'),
|
||||
(2, True, 100.0, 'mean', 1, 'zero'),
|
||||
(3, False, 1.0, 'sum', 2, 'zero'),
|
||||
(1, True, 0.5, 'mean', 3, 'none'),
|
||||
)
|
||||
def testCorrect(self, n, normalize, l2_norm_clip, agg_method,
|
||||
keep_sparse_threshold, unconnected):
|
||||
"""Tests the correctness of the computation."""
|
||||
loss_and_vars_fn = self._get_loss_and_vars_fn(n)
|
||||
true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
|
||||
unconnected)
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
loss, test_vars = loss_and_vars_fn()
|
||||
results = cag.clip_and_aggregate_gradients(
|
||||
tape,
|
||||
loss,
|
||||
test_vars,
|
||||
normalize=normalize,
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
aggregate_method=agg_method,
|
||||
unconnected_gradients=unconnected,
|
||||
keep_sparse_threshold=keep_sparse_threshold)
|
||||
for r, t in zip(results, true_grads):
|
||||
if t is None:
|
||||
self.assertIsNone(r)
|
||||
else:
|
||||
r = self._to_dense_array(r)
|
||||
self.assertAllCloseAccordingToType(r, t)
|
||||
|
||||
@parameterized.parameters(
|
||||
(6, True),
|
||||
(6, False),
|
||||
(1, True),
|
||||
(1, False),
|
||||
)
|
||||
def testTargetShape(self, n, keepdims):
|
||||
"""Tests target gets vectorized regardless of their original shape."""
|
||||
loss_and_vars_fn = self._get_loss_and_vars_fn(n, keepdims)
|
||||
true_grads = self._get_true_grads(n)
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
loss, test_vars = loss_and_vars_fn()
|
||||
results = cag.clip_and_aggregate_gradients(tape, loss, test_vars)
|
||||
for r, t in zip(results, true_grads):
|
||||
if t is None:
|
||||
self.assertIsNone(r)
|
||||
else:
|
||||
r = self._to_dense_array(r)
|
||||
self.assertAllCloseAccordingToType(r, t)
|
||||
|
||||
@parameterized.parameters(
|
||||
(-1),
|
||||
(0),
|
||||
(4),
|
||||
(5),
|
||||
)
|
||||
def testSparse(self, keep_sparse_threshold):
|
||||
"""Tests the outcome is in the desired (dense or sparse) tensor form."""
|
||||
loss_and_vars_fn = self._get_loss_and_vars_fn(3)
|
||||
with tf.GradientTape() as tape:
|
||||
loss, test_vars = loss_and_vars_fn()
|
||||
results = cag.clip_and_aggregate_gradients(
|
||||
tape,
|
||||
loss,
|
||||
test_vars,
|
||||
normalize=False,
|
||||
l2_norm_clip=1.0,
|
||||
aggregate_method='mean',
|
||||
unconnected_gradients='zero',
|
||||
keep_sparse_threshold=keep_sparse_threshold)
|
||||
grads0, grads1, grads2 = results
|
||||
# emb0 has 4 items so grads0 should be in the sparse, i.e.
|
||||
# `tf.IndexedSlices`, form iff `keep_sparse_threshold` is in [0, 4].
|
||||
if keep_sparse_threshold >= 0 and keep_sparse_threshold <= 4:
|
||||
self.assertIsInstance(grads0, tf.IndexedSlices)
|
||||
self.assertLen(grads0.indices, 3)
|
||||
else:
|
||||
self.assertIsInstance(grads0, tf.Tensor)
|
||||
# grads1 and grads2 should always be in the dense, i.e. `tf.Tensor`, form.
|
||||
self.assertIsInstance(grads1, tf.Tensor)
|
||||
self.assertIsInstance(grads2, tf.Tensor)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue