forked from 626_privacy/tensorflow_privacy
Created the optional unroll_microbatches parameter for the DpOptimizerClass as a workaround for b/122613513.
PiperOrigin-RevId: 229955297
This commit is contained in:
parent
87ec1a2e77
commit
6c5c39c4f2
2 changed files with 53 additions and 12 deletions
|
@ -35,13 +35,23 @@ def make_optimizer_class(cls):
|
||||||
class DPOptimizerClass(cls):
|
class DPOptimizerClass(cls):
|
||||||
"""Differentially private subclass of given class cls."""
|
"""Differentially private subclass of given class cls."""
|
||||||
|
|
||||||
def __init__(self, l2_norm_clip, noise_multiplier, num_microbatches, *args,
|
def __init__(
|
||||||
|
self,
|
||||||
|
l2_norm_clip,
|
||||||
|
noise_multiplier,
|
||||||
|
num_microbatches,
|
||||||
|
unroll_microbatches=False,
|
||||||
|
*args, # pylint: disable=keyword-arg-before-vararg
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(DPOptimizerClass, self).__init__(*args, **kwargs)
|
super(DPOptimizerClass, self).__init__(*args, **kwargs)
|
||||||
stddev = l2_norm_clip * noise_multiplier
|
stddev = l2_norm_clip * noise_multiplier
|
||||||
self._num_microbatches = num_microbatches
|
self._num_microbatches = num_microbatches
|
||||||
self._private_query = gaussian_query.GaussianAverageQuery(
|
self._private_query = gaussian_query.GaussianAverageQuery(
|
||||||
l2_norm_clip, stddev, num_microbatches)
|
l2_norm_clip, stddev, num_microbatches)
|
||||||
|
# TODO(b/122613513): Set unroll_microbatches=True to avoid this bug.
|
||||||
|
# Beware: When num_microbatches is large (>100), enabling this parameter
|
||||||
|
# may cause an OOM error.
|
||||||
|
self._unroll_microbatches = unroll_microbatches
|
||||||
self._global_state = self._private_query.initial_global_state()
|
self._global_state = self._private_query.initial_global_state()
|
||||||
|
|
||||||
def compute_gradients(self,
|
def compute_gradients(self,
|
||||||
|
@ -68,9 +78,7 @@ def make_optimizer_class(cls):
|
||||||
grads_list = list(grads)
|
grads_list = list(grads)
|
||||||
sample_state = self._private_query.accumulate_record(
|
sample_state = self._private_query.accumulate_record(
|
||||||
sample_params, sample_state, grads_list)
|
sample_params, sample_state, grads_list)
|
||||||
return [tf.add(i, 1), sample_state]
|
return sample_state
|
||||||
|
|
||||||
i = tf.constant(0)
|
|
||||||
|
|
||||||
if var_list is None:
|
if var_list is None:
|
||||||
var_list = (
|
var_list = (
|
||||||
|
@ -79,14 +87,20 @@ def make_optimizer_class(cls):
|
||||||
sample_state = self._private_query.initial_sample_state(
|
sample_state = self._private_query.initial_sample_state(
|
||||||
self._global_state, var_list)
|
self._global_state, var_list)
|
||||||
|
|
||||||
# Use of while_loop here requires that sample_state be a nested structure
|
if self._unroll_microbatches:
|
||||||
# of tensors. In general, we would prefer to allow it to be an arbitrary
|
for idx in range(self._num_microbatches):
|
||||||
# opaque type.
|
sample_state = process_microbatch(idx, sample_state)
|
||||||
_, final_state = tf.while_loop(
|
else:
|
||||||
lambda i, _: tf.less(i, self._num_microbatches), process_microbatch,
|
# Use of while_loop here requires that sample_state be a nested
|
||||||
[i, sample_state])
|
# structure of tensors. In general, we would prefer to allow it to be
|
||||||
|
# an arbitrary opaque type.
|
||||||
|
cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
|
||||||
|
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]
|
||||||
|
idx = tf.constant(0)
|
||||||
|
_, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
|
||||||
|
|
||||||
final_grads, self._global_state = (
|
final_grads, self._global_state = (
|
||||||
self._private_query.get_noised_average(final_state,
|
self._private_query.get_noised_average(sample_state,
|
||||||
self._global_state))
|
self._global_state))
|
||||||
|
|
||||||
return list(zip(final_grads, var_list))
|
return list(zip(final_grads, var_list))
|
||||||
|
|
|
@ -136,6 +136,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def testEstimator(self):
|
def testEstimator(self):
|
||||||
"""Tests that DP optimizers work with tf.estimator."""
|
"""Tests that DP optimizers work with tf.estimator."""
|
||||||
|
|
||||||
def linear_model_fn(features, labels, mode):
|
def linear_model_fn(features, labels, mode):
|
||||||
preds = tf.keras.layers.Dense(
|
preds = tf.keras.layers.Dense(
|
||||||
1, activation='linear', name='dense').apply(features['x'])
|
1, activation='linear', name='dense').apply(features['x'])
|
||||||
|
@ -173,6 +174,32 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
true_weights,
|
true_weights,
|
||||||
atol=1.0)
|
atol=1.0)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
|
||||||
|
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
|
||||||
|
('DPAdam', dp_optimizer.DPAdamOptimizer))
|
||||||
|
def testUnrollMicrobatches(self, cls):
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
var0 = tf.Variable([1.0, 2.0])
|
||||||
|
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||||
|
|
||||||
|
opt = cls(
|
||||||
|
l2_norm_clip=1.0e9,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
num_microbatches=4,
|
||||||
|
learning_rate=2.0,
|
||||||
|
unroll_microbatches=True)
|
||||||
|
|
||||||
|
self.evaluate(tf.global_variables_initializer())
|
||||||
|
# Fetch params to validate initial values
|
||||||
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
|
|
||||||
|
# Expected gradient is sum of differences divided by number of
|
||||||
|
# microbatches.
|
||||||
|
gradient_op = opt.compute_gradients(loss(data0, var0), [var0])
|
||||||
|
grads_and_vars = sess.run(gradient_op)
|
||||||
|
self.assertAllCloseAccordingToType([-2.5, -2.5], grads_and_vars[0][0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue