Remove pfor dependency in BUILD file, and strengthen unit tests for clip_and_aggregate_gradients.py.

PiperOrigin-RevId: 482050282
This commit is contained in:
Steve Chien 2022-10-18 16:20:56 -07:00 committed by A. Unique TensorFlower
parent 4aa531faa4
commit 0fcfd0bf69
3 changed files with 70 additions and 30 deletions

View file

@ -15,7 +15,6 @@ py_library(
"clip_and_aggregate_gradients.py", "clip_and_aggregate_gradients.py",
], ],
srcs_version = "PY3", srcs_version = "PY3",
deps = ["//third_party/tensorflow/python/ops/parallel_for:control_flow_ops"],
) )
py_library( py_library(

View file

@ -191,9 +191,8 @@ def clip_and_aggregate_gradients(
output = control_flow_ops.pfor(loop_fn, target_size) output = control_flow_ops.pfor(loop_fn, target_size)
except ValueError as err: except ValueError as err:
raise ValueError( raise ValueError(
'Encountered an exception while vectorizing the ' 'Encountered an exception while vectorizing the jacobian computation. '
'batch_jacobian computation. Vectorization can be disabled by ' 'Consider using a non-vectorized optimizer instead.') from err
'setting experimental_use_pfor to False.') from err
grads = [] grads = []
for i, out in enumerate(output): for i, out in enumerate(output):
if out is not None: if out is not None:

View file

@ -22,8 +22,8 @@ from tensorflow_privacy.privacy.optimizers import clip_and_aggregate_gradients a
class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase): class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
"""Tests clip_and_aggreate_gradients.""" """Tests clip_and_aggreate_gradients."""
def _get_loss_and_vars_fn(self, n, keepdims=False): def _get_loss_and_vars(self, n, keepdims=False, is_callable=False):
"""Returns the function for creating the loss and variables.""" """Returns the loss and variable tensors or callables."""
# The "model" here consists of both sparse and dense parameters to make sure # The "model" here consists of both sparse and dense parameters to make sure
# `clip_and_aggregate_gradients` computes the gradients in the correct way # `clip_and_aggregate_gradients` computes the gradients in the correct way
# and in the right format. The sparse layer is the embedding layer `emb0`, # and in the right format. The sparse layer is the embedding layer `emb0`,
@ -51,9 +51,20 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
return 0.5 * tf.reduce_sum( return 0.5 * tf.reduce_sum(
input_tensor=tf.math.squared_difference(val0, val1), axis=1) input_tensor=tf.math.squared_difference(val0, val1), axis=1)
def _loss_and_vars_fn(): if is_callable:
# We concatenate the embeddings with some constant values to make sure
# backprop does only go through those gathered indices. def loss_fn():
val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
return tf.reduce_sum(
tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
keepdims=keepdims,
axis=1)
def vars_fn():
return (emb0.embeddings, var1, dummy_var)
return loss_fn, vars_fn
else:
val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0) val0 = tf.concat([emb0(ind0), tf.constant([[0.0, 0.0]])], axis=0)
loss = tf.reduce_sum( loss = tf.reduce_sum(
tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]), tf.reshape(_loss(data0, val0) + _loss(data1, var1), [n, -1]),
@ -61,8 +72,6 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
axis=1) axis=1)
return loss, (emb0.embeddings, var1, dummy_var) return loss, (emb0.embeddings, var1, dummy_var)
return _loss_and_vars_fn
def _get_true_grads(self, def _get_true_grads(self,
n, n,
normalize=False, normalize=False,
@ -130,24 +139,59 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
(3, False, 1.0, 'sum', 2, 'zero'), (3, False, 1.0, 'sum', 2, 'zero'),
(1, True, 0.5, 'mean', 3, 'none'), (1, True, 0.5, 'mean', 3, 'none'),
) )
def testCorrect(self, n, normalize, l2_norm_clip, agg_method, def testCorrectTensor(self, n, normalize, l2_norm_clip, agg_method,
keep_sparse_threshold, unconnected): keep_sparse_threshold, unconnected):
"""Tests the correctness of the computation.""" """Tests the correctness of the computation for tensors."""
loss_and_vars_fn = self._get_loss_and_vars_fn(n)
true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method, true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
unconnected) unconnected)
with tf.GradientTape() as tape: tape = tf.GradientTape()
loss, test_vars = loss_and_vars_fn() with tape:
results = cag.clip_and_aggregate_gradients( loss, test_vars = self._get_loss_and_vars(n)
tape, results = cag.clip_and_aggregate_gradients(
loss, tape,
test_vars, loss,
normalize=normalize, test_vars,
l2_norm_clip=l2_norm_clip, normalize=normalize,
aggregate_method=agg_method, l2_norm_clip=l2_norm_clip,
unconnected_gradients=unconnected, aggregate_method=agg_method,
keep_sparse_threshold=keep_sparse_threshold) unconnected_gradients=unconnected,
keep_sparse_threshold=keep_sparse_threshold)
for r, t in zip(results, true_grads):
if t is None:
self.assertIsNone(r)
else:
r = self._to_dense_array(r)
self.assertAllCloseAccordingToType(r, t)
@parameterized.parameters(
(6, False, None, 'mean', -1, 'none'),
(6, True, None, 'sum', 1, 'none'),
(2, False, None, 'sum', 3, 'none'),
(2, True, 100.0, 'mean', 1, 'zero'),
(3, False, 1.0, 'sum', 2, 'zero'),
(1, True, 0.5, 'mean', 3, 'none'),
)
def testCorrectCallable(self, n, normalize, l2_norm_clip, agg_method,
keep_sparse_threshold, unconnected):
"""Tests the correctness of the computation for callables."""
true_grads = self._get_true_grads(n, normalize, l2_norm_clip, agg_method,
unconnected)
loss_fn, vars_fn = self._get_loss_and_vars(n, is_callable=True)
tape = tf.GradientTape()
with tape:
loss = loss_fn()
test_vars = vars_fn()
results = cag.clip_and_aggregate_gradients(
tape,
loss,
test_vars,
normalize=normalize,
l2_norm_clip=l2_norm_clip,
aggregate_method=agg_method,
unconnected_gradients=unconnected,
keep_sparse_threshold=keep_sparse_threshold)
for r, t in zip(results, true_grads): for r, t in zip(results, true_grads):
if t is None: if t is None:
self.assertIsNone(r) self.assertIsNone(r)
@ -163,11 +207,10 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
) )
def testTargetShape(self, n, keepdims): def testTargetShape(self, n, keepdims):
"""Tests target gets vectorized regardless of their original shape.""" """Tests target gets vectorized regardless of their original shape."""
loss_and_vars_fn = self._get_loss_and_vars_fn(n, keepdims)
true_grads = self._get_true_grads(n) true_grads = self._get_true_grads(n)
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn() loss, test_vars = self._get_loss_and_vars(n, keepdims)
results = cag.clip_and_aggregate_gradients(tape, loss, test_vars) results = cag.clip_and_aggregate_gradients(tape, loss, test_vars)
for r, t in zip(results, true_grads): for r, t in zip(results, true_grads):
if t is None: if t is None:
@ -184,9 +227,8 @@ class ClipAndAggregateGradientsTest(tf.test.TestCase, parameterized.TestCase):
) )
def testSparse(self, keep_sparse_threshold): def testSparse(self, keep_sparse_threshold):
"""Tests the outcome is in the desired (dense or sparse) tensor form.""" """Tests the outcome is in the desired (dense or sparse) tensor form."""
loss_and_vars_fn = self._get_loss_and_vars_fn(3)
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
loss, test_vars = loss_and_vars_fn() loss, test_vars = self._get_loss_and_vars(3)
results = cag.clip_and_aggregate_gradients( results = cag.clip_and_aggregate_gradients(
tape, tape,
loss, loss,