Add tests for varying number of microbatches in dp_optimizer_test.py.
PiperOrigin-RevId: 404072714
parent 977647a3bf
commit c530356ae9
1 changed file with 64 additions and 8 deletions
@@ -35,6 +35,24 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
     return 0.5 * tf.reduce_sum(
         input_tensor=tf.math.squared_difference(val0, val1), axis=1)
 
+  def _compute_expected_gradients(self, per_example_gradients,
+                                  l2_norm_clip, num_microbatches):
+    batch_size, num_vars = per_example_gradients.shape
+    microbatch_gradients = np.mean(
+        np.reshape(per_example_gradients,
+                   [num_microbatches,
+                    np.int(batch_size / num_microbatches), num_vars]),
+        axis=1)
+    microbatch_gradients_norms = np.linalg.norm(microbatch_gradients, axis=1)
+
+    def scale(x):
+      return 1.0 if x < l2_norm_clip else l2_norm_clip / x
+
+    scales = np.array(list(map(scale, microbatch_gradients_norms)))
+    mean_clipped_gradients = np.mean(
+        microbatch_gradients * scales[:, None], axis=0)
+    return mean_clipped_gradients
+
   # Parameters for testing: optimizer, num_microbatches, expected answer.
   @parameterized.named_parameters(
       ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
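The `_compute_expected_gradients` helper added above mirrors the microbatch clipping that the DP optimizers perform: per-example gradients are averaged within each microbatch, each microbatch gradient is clipped to `l2_norm_clip` in L2 norm, and the clipped gradients are averaged across microbatches. Below is a standalone NumPy sketch of the same computation, using the data from the new `testClippingNormWithMicrobatches` test; the names and the `num_microbatches` value are illustrative and not part of the commit.

import numpy as np

# Per-example gradients of the test's loss at var0 = [0, 0]: var0 - data0.
per_example_gradients = np.array([[0.0, 0.0]]) - np.array(
    [[3.0, 4.0], [6.0, 8.0], [-9.0, -12.0], [-12.0, -16.0]])
l2_norm_clip = 1.0
num_microbatches = 2  # the test is parameterized over 1, 2 and 4

batch_size, num_vars = per_example_gradients.shape
# Average the per-example gradients within each microbatch.
microbatch_gradients = np.mean(
    per_example_gradients.reshape(
        [num_microbatches, batch_size // num_microbatches, num_vars]),
    axis=1)
# Clip each microbatch gradient to l2_norm_clip in L2 norm.
norms = np.linalg.norm(microbatch_gradients, axis=1)
scales = np.minimum(1.0, l2_norm_clip / norms)
# Average the clipped microbatch gradients; this is the value the test
# expects opt.compute_gradients to return (the query's noise stddev is 0).
expected_gradient = np.mean(microbatch_gradients * scales[:, None], axis=0)
print(expected_gradient)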
@@ -98,18 +116,56 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
 
   @parameterized.named_parameters(
-      ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
-      ('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
-      ('DPAdam', dp_optimizer.DPAdamOptimizer),
-      ('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer))
-  def testNoiseMultiplier(self, cls):
+      ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1),
+      ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2),
+      ('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4),
+  )
+  def testClippingNormWithMicrobatches(self, cls, num_microbatches):
+    with self.cached_session() as sess:
+      var0 = tf.Variable([0.0, 0.0])
+      data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0], [-9.0, -12.0],
+                           [-12.0, -16.0]])
+
+      l2_norm_clip = 1.0
+      dp_sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip, 0.0)
+
+      opt = cls(dp_sum_query, num_microbatches=num_microbatches,
+                learning_rate=2.0)
+
+      self.evaluate(tf.global_variables_initializer())
+      # Fetch params to validate initial values
+      var_np = self.evaluate(var0)
+      self.assertAllClose([0.0, 0.0], var_np)
+
+      # Compute expected gradient, which is the sum of differences.
+      data_np = self.evaluate(data0)
+      per_example_gradients = var_np - data_np
+      mean_clipped_gradients = self._compute_expected_gradients(
+          per_example_gradients, l2_norm_clip, num_microbatches)
+
+      # Compare actual with expected gradients.
+      gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
+      grads_and_vars = sess.run(gradient_op)
+      print('mean_clipped_gradients: ', mean_clipped_gradients)
+      self.assertAllCloseAccordingToType(mean_clipped_gradients,
+                                         grads_and_vars[0][0])
+
+  @parameterized.named_parameters(
+      ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1),
+      ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2),
+      ('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4),
+      ('DPAdagrad', dp_optimizer.DPAdagradOptimizer, 1),
+      ('DPAdam', dp_optimizer.DPAdamOptimizer, 1),
+      ('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer, 1))
+  def testNoiseMultiplier(self, cls, num_microbatches):
     with self.cached_session() as sess:
       var0 = tf.Variable([0.0])
-      data0 = tf.Variable([[0.0]])
+      data0 = tf.Variable([[0.0], [0.0], [0.0], [0.0]])
 
       dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
 
-      opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
+      opt = cls(
+          dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
 
       self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
@@ -122,7 +178,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
         grads.append(grads_and_vars[0][0])
 
       # Test standard deviation is close to l2_norm_clip * noise_multiplier.
-      self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
+      self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5)
 
   @mock.patch('absl.logging.warning')
   def testComputeGradientsOverrideWarning(self, mock_logging):
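The change in the last hunk accounts for microbatch averaging in `testNoiseMultiplier`: per the test's own comment, the query adds noise with standard deviation l2_norm_clip * noise_multiplier = 2.0 * 4.0 to the summed clipped gradient, and dividing that sum by `num_microbatches` shrinks the noise standard deviation by the same factor, hence the updated assertion `2.0 * 4.0 / num_microbatches`. A small NumPy simulation of that scaling follows; it is a sketch, not part of the commit, and the names are illustrative.

import numpy as np

l2_norm_clip = 4.0
noise_multiplier = 2.0
stddev = l2_norm_clip * noise_multiplier  # 8.0, the query's noise stddev

rng = np.random.default_rng(0)
for num_microbatches in (1, 2, 4):
  # The optimizer returns the noisy sum of microbatch gradients divided by
  # num_microbatches, so the noise stddev shrinks by the same factor.
  noisy_sums = rng.normal(0.0, stddev, size=100_000)
  grads = noisy_sums / num_microbatches
  print(num_microbatches, np.std(grads))  # close to 8.0 / num_microbatches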