Add tests for varying number of microbatches in dp_optimizer_test.py.
PiperOrigin-RevId: 404072714
This commit is contained in:
parent
977647a3bf
commit
c530356ae9
1 changed files with 64 additions and 8 deletions
|
@ -35,6 +35,24 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
return 0.5 * tf.reduce_sum(
|
||||
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
def _compute_expected_gradients(self, per_example_gradients,
|
||||
l2_norm_clip, num_microbatches):
|
||||
batch_size, num_vars = per_example_gradients.shape
|
||||
microbatch_gradients = np.mean(
|
||||
np.reshape(per_example_gradients,
|
||||
[num_microbatches,
|
||||
np.int(batch_size / num_microbatches), num_vars]),
|
||||
axis=1)
|
||||
microbatch_gradients_norms = np.linalg.norm(microbatch_gradients, axis=1)
|
||||
|
||||
def scale(x):
|
||||
return 1.0 if x < l2_norm_clip else l2_norm_clip / x
|
||||
|
||||
scales = np.array(list(map(scale, microbatch_gradients_norms)))
|
||||
mean_clipped_gradients = np.mean(
|
||||
microbatch_gradients * scales[:, None], axis=0)
|
||||
return mean_clipped_gradients
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches, expected answer.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
|
||||
|
@ -98,18 +116,56 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
|
||||
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
|
||||
('DPAdam', dp_optimizer.DPAdamOptimizer),
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer))
|
||||
def testNoiseMultiplier(self, cls):
|
||||
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1),
|
||||
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2),
|
||||
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4),
|
||||
)
|
||||
def testClippingNormWithMicrobatches(self, cls, num_microbatches):
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([0.0, 0.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0], [-9.0, -12.0],
|
||||
[-12.0, -16.0]])
|
||||
|
||||
l2_norm_clip = 1.0
|
||||
dp_sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip, 0.0)
|
||||
|
||||
opt = cls(dp_sum_query, num_microbatches=num_microbatches,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
var_np = self.evaluate(var0)
|
||||
self.assertAllClose([0.0, 0.0], var_np)
|
||||
|
||||
# Compute expected gradient, which is the sum of differences.
|
||||
data_np = self.evaluate(data0)
|
||||
per_example_gradients = var_np - data_np
|
||||
mean_clipped_gradients = self._compute_expected_gradients(
|
||||
per_example_gradients, l2_norm_clip, num_microbatches)
|
||||
|
||||
# Compare actual with expected gradients.
|
||||
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
|
||||
grads_and_vars = sess.run(gradient_op)
|
||||
print('mean_clipped_gradients: ', mean_clipped_gradients)
|
||||
self.assertAllCloseAccordingToType(mean_clipped_gradients,
|
||||
grads_and_vars[0][0])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1),
|
||||
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2),
|
||||
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4),
|
||||
('DPAdagrad', dp_optimizer.DPAdagradOptimizer, 1),
|
||||
('DPAdam', dp_optimizer.DPAdamOptimizer, 1),
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer, 1))
|
||||
def testNoiseMultiplier(self, cls, num_microbatches):
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([0.0])
|
||||
data0 = tf.Variable([[0.0]])
|
||||
data0 = tf.Variable([[0.0], [0.0], [0.0], [0.0]])
|
||||
|
||||
dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
|
||||
|
||||
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
|
||||
opt = cls(
|
||||
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
|
@ -122,7 +178,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
grads.append(grads_and_vars[0][0])
|
||||
|
||||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
|
||||
self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5)
|
||||
|
||||
@mock.patch('absl.logging.warning')
|
||||
def testComputeGradientsOverrideWarning(self, mock_logging):
|
||||
|
|
Loading…
Reference in a new issue