Enable optimizers to handle variables whose gradients are None.
PiperOrigin-RevId: 322193798
This commit is contained in:
parent
1a959eec34
commit
87c01eb2f5
3 changed files with 46 additions and 11 deletions
|
@ -207,6 +207,10 @@ def zeros_like(arg):
|
|||
return tf.zeros(arg.shape, arg.dtype)
|
||||
|
||||
|
||||
def safe_add(x, y):
|
||||
return x if y is None else tf.add(x, y)
|
||||
|
||||
|
||||
class SumAggregationDPQuery(DPQuery):
|
||||
"""Base class for DPQueries that aggregate via sum."""
|
||||
|
||||
|
@ -214,7 +218,7 @@ class SumAggregationDPQuery(DPQuery):
|
|||
return tf.nest.map_structure(zeros_like, template)
|
||||
|
||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||
return tf.nest.map_structure(tf.add, sample_state, preprocessed_record)
|
||||
return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||
return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2)
|
||||
|
|
|
@ -135,14 +135,13 @@ def make_optimizer_class(cls):
|
|||
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients(
|
||||
tf.reduce_mean(input_tensor=tf.gather(
|
||||
microbatches_losses, [i])), var_list, gate_gradients,
|
||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||
grads_list = [
|
||||
g if g is not None else tf.zeros_like(v)
|
||||
for (g, v) in zip(list(grads), var_list)
|
||||
]
|
||||
grads, _ = zip(
|
||||
*super(DPOptimizerClass, self).compute_gradients(
|
||||
tf.reduce_mean(
|
||||
input_tensor=tf.gather(microbatches_losses,
|
||||
[i])), var_list, gate_gradients,
|
||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||
grads_list = list(grads)
|
||||
sample_state = self._dp_sum_query.accumulate_record(
|
||||
sample_params, sample_state, grads_list)
|
||||
return sample_state
|
||||
|
@ -172,7 +171,10 @@ def make_optimizer_class(cls):
|
|||
sample_state, self._global_state))
|
||||
|
||||
def normalize(v):
|
||||
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
|
||||
try:
|
||||
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
final_grads = tf.nest.map_structure(normalize, grad_sums)
|
||||
|
||||
|
|
|
@ -267,6 +267,35 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
grads_and_vars = opt.compute_gradients(self._loss(data0, var0), [var0])
|
||||
opt.apply_gradients(grads_and_vars)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
|
||||
[-2.5, -2.5]),
|
||||
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
|
||||
[-2.5, -2.5]),
|
||||
)
|
||||
def testNoneGradients(self, cls, num_microbatches, expected_answer):
|
||||
"""Tests that optimizers can handle variables whose gradients are None."""
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||
# Create a string variable whose gradient will be None.
|
||||
extra_variable = tf.Variable('foo', trainable=True, dtype=tf.string)
|
||||
|
||||
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
|
||||
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6,
|
||||
num_microbatches / 1e6)
|
||||
|
||||
opt = cls(
|
||||
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
|
||||
minimize_op = opt.minimize(
|
||||
loss=self._loss(data0, var0), var_list=[var0, extra_variable])
|
||||
sess.run(minimize_op)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue