Remove pfor dependency in BUILD file, and strengthen unit tests for clip_and_aggregate_gradients.py.

PiperOrigin-RevId: 482050282
This commit is contained in:
Steve Chien 2022-10-18 16:20:56 -07:00 committed by A. Unique TensorFlower
parent 4aa531faa4
commit 0fcfd0bf69
3 changed files with 70 additions and 30 deletions

View file

@ -15,7 +15,6 @@ py_library(
"clip_and_aggregate_gradients.py",
],
srcs_version = "PY3",
deps = ["//third_party/tensorflow/python/ops/parallel_for:control_flow_ops"],
)
py_library(

View file

@ -191,9 +191,8 @@ def clip_and_aggregate_gradients(
output = control_flow_ops.pfor(loop_fn, target_size)
except ValueError as err:
raise ValueError(
'Encountered an exception while vectorizing the '
'batch_jacobian computation. Vectorization can be disabled by '
'setting experimental_use_pfor to False.') from err
'Encountered an exception while vectorizing the jacobian computation. '
'Consider using a non-vectorized optimizer instead.') from err
grads = []
for i, out in enumerate(output):
if out is not None:

View file

@ -22,8 +22,8 @@ from tensorflow_privacy.privacy.optimizers import clip_and_aggregate_gradients a
class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
"""Tests clip_and_aggreate_gradients."""
def _get_loss_and_vars_fn(self, n, keepdims=False):
"""Returns the function for creating the loss and variables."""
def _get_loss_and_vars(self, n, keepdims=False, is_callable=False):
"""Returns the loss and variable tensors or callables."""
# The "model" here consists of both sparse and dense parameters to make sure
# `clip_and_aggregate_gradients` computes the gradients in the correct way
# and in the right format. The sparse layer is the embedding layer `emb0`,
@ -51,9 +51,20 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
return 0.5 * tf.reduce_sum(
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
def _loss_and_vars_fn():
# We concatenate the embeddings with some constant values to make sure
# backprop does only go through those gathered indices.
if is_callable:
def loss_fn():
val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
return tf.reduce_sum(
tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
keepdims=keepdims,
axis=1)
def vars_fn():
return (emb0.embeddings, var1, dummy_var)
return loss_fn, vars_fn
else:
val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
loss = tf.reduce_sum(
tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
@ -61,8 +72,6 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
axis=1)
return loss, (emb0.embeddings, var1, dummy_var)
return _loss_and_vars_fn
def _get_true_grads(self,
n,
normalize=False,
@ -130,15 +139,50 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
(3, False, 1.0, 'sum', 2, 'zero'),
(1, True, 0.5, 'mean', 3, 'none'),
)
def testCorrect(self, n, normalize, l2_norm_clip, agg_method,
def testCorrectTensor(self, n, normalize, l2_norm_clip, agg_method,
keep_sparse_threshold, unconnected):
"""Tests the correctness of the computation."""
loss_and_vars_fn = self._get_loss_and_vars_fn(n)
"""Tests the correctness of the computation for tensors."""
true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
unconnected)
with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn()
tape = tf.GradientTape()
with tape:
loss, test_vars = self._get_loss_and_vars(n)
results = cag.clip_and_aggregate_gradients(
tape,
loss,
test_vars,
normalize=normalize,
l2_norm_clip=l2_norm_clip,
aggregate_method=agg_method,
unconnected_gradients=unconnected,
keep_sparse_threshold=keep_sparse_threshold)
for r, t in zip(results, true_grads):
if t is None:
self.assertIsNone(r)
else:
r = self._to_dense_array(r)
self.assertAllCloseAccordingToType(r, t)
@parameterized.parameters(
(6, False, None, 'mean', -1, 'none'),
(6, True, None, 'sum', 1, 'none'),
(2, False, None, 'sum', 3, 'none'),
(2, True, 100.0, 'mean', 1, 'zero'),
(3, False, 1.0, 'sum', 2, 'zero'),
(1, True, 0.5, 'mean', 3, 'none'),
)
def testCorrectCallable(self, n, normalize, l2_norm_clip, agg_method,
keep_sparse_threshold, unconnected):
"""Tests the correctness of the computation for callables."""
true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
unconnected)
loss_fn, vars_fn = self._get_loss_and_vars(n, is_callable=True)
tape = tf.GradientTape()
with tape:
loss = loss_fn()
test_vars = vars_fn()
results = cag.clip_and_aggregate_gradients(
tape,
loss,
@ -163,11 +207,10 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
)
def testTargetShape(self, n, keepdims):
"""Tests target gets vectorized regardless of their original shape."""
loss_and_vars_fn = self._get_loss_and_vars_fn(n, keepdims)
true_grads = self._get_true_grads(n)
with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn()
loss, test_vars = self._get_loss_and_vars(n, keepdims)
results = cag.clip_and_aggregate_gradients(tape, loss, test_vars)
for r, t in zip(results, true_grads):
if t is None:
@ -184,9 +227,8 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
)
def testSparse(self, keep_sparse_threshold):
"""Tests the outcome is in the desired (dense or sparse) tensor form."""
loss_and_vars_fn = self._get_loss_and_vars_fn(3)
with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn()
loss, test_vars = self._get_loss_and_vars(3)
results = cag.clip_and_aggregate_gradients(
tape,
loss,