Enable optimizers to handle variables whose gradients are None.

PiperOrigin-RevId: 322193798
Commit 87c01eb2f5 (parent 1a959eec34)
Author: Steve Chien, 2020-07-20 11:58:57 -07:00; committed by A. Unique TensorFlower
3 changed files with 46 additions and 11 deletions

Changed file 1 of 3:

@@ -207,6 +207,10 @@ def zeros_like(arg):
   return tf.zeros(arg.shape, arg.dtype)
 
 
+def safe_add(x, y):
+  return x if y is None else tf.add(x, y)
+
+
 class SumAggregationDPQuery(DPQuery):
   """Base class for DPQueries that aggregate via sum."""
@@ -214,7 +218,7 @@ class SumAggregationDPQuery(DPQuery):
     return tf.nest.map_structure(zeros_like, template)
 
   def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
-    return tf.nest.map_structure(tf.add, sample_state, preprocessed_record)
+    return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
 
   def merge_sample_states(self, sample_state_1, sample_state_2):
-    return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
+    return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2)

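For context: safe_add treats a None addend as a no-op, so a variable that
produced no gradient keeps its accumulated sum instead of crashing tf.add.
A minimal standalone sketch of the aggregation behavior (the state/record
values here are illustrative, not from the library):

    import tensorflow as tf

    def safe_add(x, y):
      # y is None when a variable received no gradient; keep x unchanged.
      return x if y is None else tf.add(x, y)

    # Two variables; the second one produced no gradient this round.
    state = [tf.zeros([2]), tf.zeros([3])]
    record = [tf.ones([2]), None]

    # tf.nest treats None as a leaf, so map_structure pairs it with the
    # matching state entry and safe_add leaves that entry untouched.
    new_state = tf.nest.map_structure(safe_add, state, record)
    print(new_state)  # [1., 1.] and [0., 0., 0.]
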
Changed file 2 of 3:

@@ -135,14 +135,13 @@ def make_optimizer_class(cls):
         def process_microbatch(i, sample_state):
           """Process one microbatch (record) with privacy helper."""
-          grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients(
-              tf.reduce_mean(input_tensor=tf.gather(
-                  microbatches_losses, [i])), var_list, gate_gradients,
-              aggregation_method, colocate_gradients_with_ops, grad_loss))
-          grads_list = [
-              g if g is not None else tf.zeros_like(v)
-              for (g, v) in zip(list(grads), var_list)
-          ]
+          grads, _ = zip(
+              *super(DPOptimizerClass, self).compute_gradients(
+                  tf.reduce_mean(
+                      input_tensor=tf.gather(microbatches_losses,
+                                             [i])), var_list, gate_gradients,
+                  aggregation_method, colocate_gradients_with_ops, grad_loss))
+          grads_list = list(grads)
           sample_state = self._dp_sum_query.accumulate_record(
               sample_params, sample_state, grads_list)
           return sample_state
@@ -172,7 +171,10 @@ def make_optimizer_class(cls):
                 sample_state, self._global_state))
 
         def normalize(v):
-          return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
+          try:
+            return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
+          except TypeError:
+            return None
 
         final_grads = tf.nest.map_structure(normalize, grad_sums)

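The try/except added to normalize turns the division of a missing (None)
gradient sum back into a None gradient, matching what stock TF1 optimizers
already produce and tolerate: compute_gradients returns a (None, var) pair
for any variable the loss does not depend on, and apply_gradients skips
such pairs. A small standalone sketch of that baseline behavior (toy
variables, not part of this patch):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.Variable(1.0)
    # A string variable cannot take part in the loss, so its gradient is None.
    s = tf.Variable('foo', trainable=True, dtype=tf.string)
    loss = tf.square(x)

    opt = tf.train.GradientDescentOptimizer(learning_rate=1.0)
    grads_and_vars = opt.compute_gradients(loss, var_list=[x, s])
    print([g is None for g, _ in grads_and_vars])  # [False, True]

    # apply_gradients silently skips the (None, s) pair.
    train_op = opt.apply_gradients(grads_and_vars)
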
Changed file 3 of 3:

@@ -267,6 +267,35 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       grads_and_vars = opt.compute_gradients(self._loss(data0, var0), [var0])
       opt.apply_gradients(grads_and_vars)
 
+  @parameterized.named_parameters(
+      ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
+       [-2.5, -2.5]),
+      ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
+       [-2.5, -2.5]),
+  )
+  def testNoneGradients(self, cls, num_microbatches, expected_answer):
+    """Tests that optimizers can handle variables whose gradients are None."""
+    with self.cached_session() as sess:
+      var0 = tf.Variable([1.0, 2.0])
+      data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
+      # Create a string variable whose gradient will be None.
+      extra_variable = tf.Variable('foo', trainable=True, dtype=tf.string)
+
+      dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
+      dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6,
+                                                    num_microbatches / 1e6)
+
+      opt = cls(
+          dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
+
+      self.evaluate(tf.global_variables_initializer())
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+
+      minimize_op = opt.minimize(
+          loss=self._loss(data0, var0), var_list=[var0, extra_variable])
+      sess.run(minimize_op)
+
 
 if __name__ == '__main__':
   tf.test.main()
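
End to end, the change means a DP optimizer no longer requires every entry
in var_list to receive a gradient. A hypothetical user-side sketch that
mirrors the new test (import paths assume the tensorflow_privacy layout at
this revision; all values are illustrative):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    from tensorflow_privacy.privacy.analysis import privacy_ledger
    from tensorflow_privacy.privacy.dp_query import gaussian_query
    from tensorflow_privacy.privacy.optimizers import dp_optimizer

    var0 = tf.Variable([1.0, 2.0])
    # Rides along in var_list but takes no part in the loss -> None gradient.
    extra = tf.Variable('foo', trainable=True, dtype=tf.string)

    # Vector of per-example losses, as the microbatched optimizer expects.
    data = tf.constant([[3.0, 4.0], [5.0, 6.0]])
    loss = 0.5 * tf.reduce_sum(tf.math.squared_difference(data, var0), axis=1)

    dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
    dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 2e-6)

    opt = dp_optimizer.DPGradientDescentOptimizer(
        dp_sum_query, num_microbatches=2, learning_rate=2.0)

    # Before this commit, the None gradient for extra broke the tf.add
    # based sum aggregation; now minimize builds and runs cleanly.
    train_op = opt.minimize(loss=loss, var_list=[var0, extra])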