Enable optimizers to handle variables whose gradients are None.
PiperOrigin-RevId: 322193798
This commit is contained in:
parent
1a959eec34
commit
87c01eb2f5
3 changed files with 46 additions and 11 deletions
|
@ -207,6 +207,10 @@ def zeros_like(arg):
|
||||||
return tf.zeros(arg.shape, arg.dtype)
|
return tf.zeros(arg.shape, arg.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_add(x, y):
|
||||||
|
return x if y is None else tf.add(x, y)
|
||||||
|
|
||||||
|
|
||||||
class SumAggregationDPQuery(DPQuery):
|
class SumAggregationDPQuery(DPQuery):
|
||||||
"""Base class for DPQueries that aggregate via sum."""
|
"""Base class for DPQueries that aggregate via sum."""
|
||||||
|
|
||||||
|
@ -214,7 +218,7 @@ class SumAggregationDPQuery(DPQuery):
|
||||||
return tf.nest.map_structure(zeros_like, template)
|
return tf.nest.map_structure(zeros_like, template)
|
||||||
|
|
||||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||||
return tf.nest.map_structure(tf.add, sample_state, preprocessed_record)
|
return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
|
||||||
|
|
||||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2)
|
||||||
|
|
|
@ -135,14 +135,13 @@ def make_optimizer_class(cls):
|
||||||
|
|
||||||
def process_microbatch(i, sample_state):
|
def process_microbatch(i, sample_state):
|
||||||
"""Process one microbatch (record) with privacy helper."""
|
"""Process one microbatch (record) with privacy helper."""
|
||||||
grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients(
|
grads, _ = zip(
|
||||||
tf.reduce_mean(input_tensor=tf.gather(
|
*super(DPOptimizerClass, self).compute_gradients(
|
||||||
microbatches_losses, [i])), var_list, gate_gradients,
|
tf.reduce_mean(
|
||||||
|
input_tensor=tf.gather(microbatches_losses,
|
||||||
|
[i])), var_list, gate_gradients,
|
||||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||||
grads_list = [
|
grads_list = list(grads)
|
||||||
g if g is not None else tf.zeros_like(v)
|
|
||||||
for (g, v) in zip(list(grads), var_list)
|
|
||||||
]
|
|
||||||
sample_state = self._dp_sum_query.accumulate_record(
|
sample_state = self._dp_sum_query.accumulate_record(
|
||||||
sample_params, sample_state, grads_list)
|
sample_params, sample_state, grads_list)
|
||||||
return sample_state
|
return sample_state
|
||||||
|
@ -172,7 +171,10 @@ def make_optimizer_class(cls):
|
||||||
sample_state, self._global_state))
|
sample_state, self._global_state))
|
||||||
|
|
||||||
def normalize(v):
|
def normalize(v):
|
||||||
|
try:
|
||||||
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
|
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
|
||||||
|
except TypeError:
|
||||||
|
return None
|
||||||
|
|
||||||
final_grads = tf.nest.map_structure(normalize, grad_sums)
|
final_grads = tf.nest.map_structure(normalize, grad_sums)
|
||||||
|
|
||||||
|
|
|
@ -267,6 +267,35 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
grads_and_vars = opt.compute_gradients(self._loss(data0, var0), [var0])
|
grads_and_vars = opt.compute_gradients(self._loss(data0, var0), [var0])
|
||||||
opt.apply_gradients(grads_and_vars)
|
opt.apply_gradients(grads_and_vars)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
|
||||||
|
[-2.5, -2.5]),
|
||||||
|
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
|
||||||
|
[-2.5, -2.5]),
|
||||||
|
)
|
||||||
|
def testNoneGradients(self, cls, num_microbatches, expected_answer):
|
||||||
|
"""Tests that optimizers can handle variables whose gradients are None."""
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
var0 = tf.Variable([1.0, 2.0])
|
||||||
|
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||||
|
# Create a string variable whose gradient will be None.
|
||||||
|
extra_variable = tf.Variable('foo', trainable=True, dtype=tf.string)
|
||||||
|
|
||||||
|
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
|
||||||
|
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6,
|
||||||
|
num_microbatches / 1e6)
|
||||||
|
|
||||||
|
opt = cls(
|
||||||
|
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
|
||||||
|
|
||||||
|
self.evaluate(tf.global_variables_initializer())
|
||||||
|
# Fetch params to validate initial values
|
||||||
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
|
|
||||||
|
minimize_op = opt.minimize(
|
||||||
|
loss=self._loss(data0, var0), var_list=[var0, extra_variable])
|
||||||
|
sess.run(minimize_op)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue