diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 58d9c55..f08a448 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -35,13 +35,23 @@ def make_optimizer_class(cls): class DPOptimizerClass(cls): """Differentially private subclass of given class cls.""" - def __init__(self, l2_norm_clip, noise_multiplier, num_microbatches, *args, - **kwargs): + def __init__( + self, + l2_norm_clip, + noise_multiplier, + num_microbatches, + unroll_microbatches=False, + *args, # pylint: disable=keyword-arg-before-vararg + **kwargs): super(DPOptimizerClass, self).__init__(*args, **kwargs) stddev = l2_norm_clip * noise_multiplier self._num_microbatches = num_microbatches self._private_query = gaussian_query.GaussianAverageQuery( l2_norm_clip, stddev, num_microbatches) + # TODO(b/122613513): Set unroll_microbatches=True to avoid the linked bug + # (a Python loop over microbatches then replaces tf.while_loop). Beware: + # with large num_microbatches (>100), unrolling may cause an OOM error. + self._unroll_microbatches = unroll_microbatches self._global_state = self._private_query.initial_global_state() def compute_gradients(self, @@ -68,9 +78,7 @@ def make_optimizer_class(cls): grads_list = list(grads) sample_state = self._private_query.accumulate_record( sample_params, sample_state, grads_list) - return [tf.add(i, 1), sample_state] - - i = tf.constant(0) + return sample_state if var_list is None: var_list = ( @@ -79,14 +87,20 @@ def make_optimizer_class(cls): sample_state = self._private_query.initial_sample_state( self._global_state, var_list) - # Use of while_loop here requires that sample_state be a nested structure - # of tensors. In general, we would prefer to allow it to be an arbitrary - # opaque type. 
- _, final_state = tf.while_loop( - lambda i, _: tf.less(i, self._num_microbatches), process_microbatch, - [i, sample_state]) + if self._unroll_microbatches: + for idx in range(self._num_microbatches): + sample_state = process_microbatch(idx, sample_state) + else: + # Use of while_loop here requires that sample_state be a nested + # structure of tensors. In general, we would prefer to allow it to be + # an arbitrary opaque type. + cond_fn = lambda i, _: tf.less(i, self._num_microbatches) + body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] + idx = tf.constant(0) + _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state]) + final_grads, self._global_state = ( - self._private_query.get_noised_average(final_state, + self._private_query.get_noised_average(sample_state, self._global_state)) return list(zip(final_grads, var_list)) diff --git a/privacy/optimizers/dp_optimizer_test.py b/privacy/optimizers/dp_optimizer_test.py index 69e131c..98fd2a3 100644 --- a/privacy/optimizers/dp_optimizer_test.py +++ b/privacy/optimizers/dp_optimizer_test.py @@ -136,6 +136,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): def testEstimator(self): """Tests that DP optimizers work with tf.estimator.""" + def linear_model_fn(features, labels, mode): preds = tf.keras.layers.Dense( 1, activation='linear', name='dense').apply(features['x']) @@ -173,6 +174,32 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): true_weights, atol=1.0) + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer), + ('DPAdagrad', dp_optimizer.DPAdagradOptimizer), + ('DPAdam', dp_optimizer.DPAdamOptimizer)) + def testUnrollMicrobatches(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + + opt = cls( + l2_norm_clip=1.0e9, + noise_multiplier=0.0, + num_microbatches=4, + learning_rate=2.0, 
unroll_microbatches=True) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + # Expected gradient is sum of differences divided by number of + # microbatches. + gradient_op = opt.compute_gradients(loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType([-2.5, -2.5], grads_and_vars[0][0]) + if __name__ == '__main__': tf.test.main()