Remove pfor dependency in BUILD file, and strengthen unit tests for clip_and_aggregate_gradients.py.
PiperOrigin-RevId: 482050282
parent 4aa531faa4
commit 0fcfd0bf69
3 changed files with 70 additions and 30 deletions
BUILD:

@@ -15,7 +15,6 @@ py_library(
         "clip_and_aggregate_gradients.py",
     ],
     srcs_version = "PY3",
-    deps = ["//third_party/tensorflow/python/ops/parallel_for:control_flow_ops"],
 )
 
 py_library(
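The dependency can be dropped even though the library still calls control_flow_ops.pfor (see the next hunk): the symbol is presumably picked up through the package's main TensorFlow dependency rather than this explicit BUILD label. A minimal sketch of the import and a toy pfor call; the loop_fn below is hypothetical and only stands in for the real one:

    import tensorflow as tf
    # Import path mirrors the removed BUILD label; pfor ships with TensorFlow itself.
    from tensorflow.python.ops.parallel_for import control_flow_ops

    def loop_fn(i):
      # Toy per-index computation; pfor runs it for i = 0, 1, 2 and stacks the results.
      return tf.gather(tf.constant([1.0, 2.0, 3.0]), i) * 2.0

    output = control_flow_ops.pfor(loop_fn, 3)  # -> [2.0, 4.0, 6.0]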
clip_and_aggregate_gradients.py:

@@ -191,9 +191,8 @@ def clip_and_aggregate_gradients(
      output = control_flow_ops.pfor(loop_fn, target_size)
    except ValueError as err:
      raise ValueError(
-          'Encountered an exception while vectorizing the '
-          'batch_jacobian computation. Vectorization can be disabled by '
-          'setting experimental_use_pfor to False.') from err
+          'Encountered an exception while vectorizing the jacobian computation. '
+          'Consider using a non-vectorized optimizer instead.') from err
  grads = []
  for i, out in enumerate(output):
    if out is not None:
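For reference, the wrapping pattern above reduces to the following minimal, runnable sketch; the toy loop_fn is hypothetical and stands in for the module's real per-example jacobian loop:

    import tensorflow as tf
    from tensorflow.python.ops.parallel_for import control_flow_ops

    def loop_fn(i):
      # Stand-in per-index computation; the real one builds per-example gradients.
      return tf.one_hot(i, depth=4)

    target_size = 4
    try:
      output = control_flow_ops.pfor(loop_fn, target_size)
    except ValueError as err:
      # Re-raise with the message introduced in this change.
      raise ValueError(
          'Encountered an exception while vectorizing the jacobian computation. '
          'Consider using a non-vectorized optimizer instead.') from err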
Unit tests for clip_and_aggregate_gradients:

@@ -22,8 +22,8 @@ from tensorflow_privacy.privacy.optimizers import clip_and_aggregate_gradients as cag
 class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
   """Tests clip_and_aggreate_gradients."""
 
-  def _get_loss_and_vars_fn(self, n, keepdims=False):
-    """Returns the function for creating the loss and variables."""
+  def _get_loss_and_vars(self, n, keepdims=False, is_callable=False):
+    """Returns the loss and variable tensors or callables."""
     # The "model" here consists of both sparse and dense parameters to make sure
     # `clip_and_aggregate_gradients` computes the gradients in the correct way
     # and in the right format. The sparse layer is the embedding layer `emb0`,
@@ -51,9 +51,20 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
       return 0.5 * tf.reduce_sum(
           input_tensor=tf.math.squared_difference(val0, val1), axis=1)
 
-    def _loss_and_vars_fn():
-      # We concatenate the embeddings with some constant values to make sure
-      # backprop does only go through those gathered indices.
+    if is_callable:
+
+      def loss_fn():
+        val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
+        return tf.reduce_sum(
+            tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
+            keepdims=keepdims,
+            axis=1)
+
+      def vars_fn():
+        return (emb0.embeddings, var1, dummy_var)
+
+      return loss_fn, vars_fn
+    else:
       val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
       loss = tf.reduce_sum(
           tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
@@ -61,8 +72,6 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
           axis=1)
       return loss, (emb0.embeddings, var1, dummy_var)
 
-    return _loss_and_vars_fn
-
   def _get_true_grads(self,
                       n,
                       normalize=False,
@@ -130,24 +139,59 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
       (3, False, 1.0, 'sum', 2, 'zero'),
       (1, True, 0.5, 'mean', 3, 'none'),
   )
-  def testCorrect(self, n, normalize, l2_norm_clip, agg_method,
-                  keep_sparse_threshold, unconnected):
-    """Tests the correctness of the computation."""
-    loss_and_vars_fn = self._get_loss_and_vars_fn(n)
+  def testCorrectTensor(self, n, normalize, l2_norm_clip, agg_method,
+                         keep_sparse_threshold, unconnected):
+    """Tests the correctness of the computation for tensors."""
     true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
                                       unconnected)
 
-    with tf.GradientTape() as tape:
-      loss, test_vars = loss_and_vars_fn()
-    results = cag.clip_and_aggregate_gradients(
-        tape,
-        loss,
-        test_vars,
-        normalize=normalize,
-        l2_norm_clip=l2_norm_clip,
-        aggregate_method=agg_method,
-        unconnected_gradients=unconnected,
-        keep_sparse_threshold=keep_sparse_threshold)
+    tape = tf.GradientTape()
+    with tape:
+      loss, test_vars = self._get_loss_and_vars(n)
+    results = cag.clip_and_aggregate_gradients(
+        tape,
+        loss,
+        test_vars,
+        normalize=normalize,
+        l2_norm_clip=l2_norm_clip,
+        aggregate_method=agg_method,
+        unconnected_gradients=unconnected,
+        keep_sparse_threshold=keep_sparse_threshold)
+    for r, t in zip(results, true_grads):
+      if t is None:
+        self.assertIsNone(r)
+      else:
+        r = self._to_dense_array(r)
+        self.assertAllCloseAccordingToType(r, t)
+
+  @parameterized.parameters(
+      (6, False, None, 'mean', -1, 'none'),
+      (6, True, None, 'sum', 1, 'none'),
+      (2, False, None, 'sum', 3, 'none'),
+      (2, True, 100.0, 'mean', 1, 'zero'),
+      (3, False, 1.0, 'sum', 2, 'zero'),
+      (1, True, 0.5, 'mean', 3, 'none'),
+  )
+  def testCorrectCallable(self, n, normalize, l2_norm_clip, agg_method,
+                           keep_sparse_threshold, unconnected):
+    """Tests the correctness of the computation for callables."""
+    true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
+                                      unconnected)
+
+    loss_fn, vars_fn = self._get_loss_and_vars(n, is_callable=True)
+    tape = tf.GradientTape()
+    with tape:
+      loss = loss_fn()
+      test_vars = vars_fn()
+    results = cag.clip_and_aggregate_gradients(
+        tape,
+        loss,
+        test_vars,
+        normalize=normalize,
+        l2_norm_clip=l2_norm_clip,
+        aggregate_method=agg_method,
+        unconnected_gradients=unconnected,
+        keep_sparse_threshold=keep_sparse_threshold)
     for r, t in zip(results, true_grads):
       if t is None:
         self.assertIsNone(r)
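Outside the test harness, the call pattern exercised by both tests reduces to the following minimal sketch. The toy model and data are hypothetical, and the argument names and representative values follow the parameterizations above; the exact accepted types are an assumption based on what the tests pass:

    import tensorflow as tf
    from tensorflow_privacy.privacy.optimizers import clip_and_aggregate_gradients as cag

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    y = tf.constant([[1.0], [0.0]])
    w = tf.Variable([[0.1], [0.2]])

    # The tape is created first and only entered as a context manager, matching the
    # tape = tf.GradientTape(); with tape: pattern used in the tests.
    tape = tf.GradientTape()
    with tape:
      per_example_loss = tf.reduce_sum(
          tf.math.squared_difference(tf.matmul(x, w), y), axis=1)

    grads = cag.clip_and_aggregate_gradients(
        tape,
        per_example_loss,
        (w,),
        normalize=False,
        l2_norm_clip=1.0,
        aggregate_method='mean',
        unconnected_gradients='none',
        keep_sparse_threshold=-1)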
@@ -163,11 +207,10 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
   )
   def testTargetShape(self, n, keepdims):
     """Tests target gets vectorized regardless of their original shape."""
-    loss_and_vars_fn = self._get_loss_and_vars_fn(n, keepdims)
     true_grads = self._get_true_grads(n)
 
     with tf.GradientTape() as tape:
-      loss, test_vars = loss_and_vars_fn()
+      loss, test_vars = self._get_loss_and_vars(n, keepdims)
     results = cag.clip_and_aggregate_gradients(tape, loss, test_vars)
     for r, t in zip(results, true_grads):
       if t is None:
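testTargetShape passes keepdims through to the helper, so the per-example loss arrives either with shape [n] or [n, 1]; a quick illustration of the distinction being exercised, using tf.reduce_sum as the helper does:

    import tensorflow as tf

    x = tf.ones([3, 4])
    print(tf.reduce_sum(x, axis=1).shape)                 # (3,)   keepdims=False
    print(tf.reduce_sum(x, axis=1, keepdims=True).shape)  # (3, 1) keepdims=True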
@@ -184,9 +227,8 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
   )
   def testSparse(self, keep_sparse_threshold):
     """Tests the outcome is in the desired (dense or sparse) tensor form."""
-    loss_and_vars_fn = self._get_loss_and_vars_fn(3)
     with tf.GradientTape() as tape:
-      loss, test_vars = loss_and_vars_fn()
+      loss, test_vars = self._get_loss_and_vars(3)
     results = cag.clip_and_aggregate_gradients(
         tape,
         loss,
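The dense-versus-sparse distinction that testSparse checks comes from how TensorFlow represents gradients of embedding lookups; a short sketch of what keep_sparse_threshold is deciding between (the test's _to_dense_array helper presumably performs a similar densification):

    import tensorflow as tf

    emb = tf.Variable(tf.ones([4, 2]))
    with tf.GradientTape() as tape:
      gathered = tf.gather(emb, [0, 2])
      loss = tf.reduce_sum(gathered)
    grad = tape.gradient(loss, emb)

    print(type(grad))                  # tf.IndexedSlices: the sparse form
    print(tf.convert_to_tensor(grad))  # densified to a full [4, 2] tensor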