diff --git a/tensorflow_privacy/privacy/dp_query/dp_query.py b/tensorflow_privacy/privacy/dp_query/dp_query.py index 045ca0c..782c893 100644 --- a/tensorflow_privacy/privacy/dp_query/dp_query.py +++ b/tensorflow_privacy/privacy/dp_query/dp_query.py @@ -207,6 +207,10 @@ def zeros_like(arg): return tf.zeros(arg.shape, arg.dtype) +def safe_add(x, y): + return x if y is None else tf.add(x, y) + + class SumAggregationDPQuery(DPQuery): """Base class for DPQueries that aggregate via sum.""" @@ -214,7 +218,7 @@ class SumAggregationDPQuery(DPQuery): return tf.nest.map_structure(zeros_like, template) def accumulate_preprocessed_record(self, sample_state, preprocessed_record): - return tf.nest.map_structure(tf.add, sample_state, preprocessed_record) + return tf.nest.map_structure(safe_add, sample_state, preprocessed_record) def merge_sample_states(self, sample_state_1, sample_state_2): - return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2) + return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py index 83bcff9..cf5519b 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py @@ -135,14 +135,13 @@ def make_optimizer_class(cls): def process_microbatch(i, sample_state): """Process one microbatch (record) with privacy helper.""" - grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients( - tf.reduce_mean(input_tensor=tf.gather( - microbatches_losses, [i])), var_list, gate_gradients, - aggregation_method, colocate_gradients_with_ops, grad_loss)) - grads_list = [ - g if g is not None else tf.zeros_like(v) - for (g, v) in zip(list(grads), var_list) - ] + grads, _ = zip( + *super(DPOptimizerClass, self).compute_gradients( + tf.reduce_mean( + input_tensor=tf.gather(microbatches_losses, + [i])), var_list, gate_gradients, + aggregation_method, colocate_gradients_with_ops, grad_loss)) + grads_list = list(grads) sample_state = self._dp_sum_query.accumulate_record( sample_params, sample_state, grads_list) return sample_state @@ -172,7 +171,10 @@ def make_optimizer_class(cls): sample_state, self._global_state)) def normalize(v): - return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32)) + try: + return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32)) + except TypeError: + return None final_grads = tf.nest.map_structure(normalize, grad_sums) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py index 2032a58..909c674 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py @@ -267,6 +267,35 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): grads_and_vars = opt.compute_gradients(self._loss(data0, var0), [var0]) opt.apply_gradients(grads_and_vars) + @parameterized.named_parameters( + ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1, + [-2.5, -2.5]), + ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2, + [-2.5, -2.5]), + ) + def testNoneGradients(self, cls, num_microbatches, expected_answer): + """Tests that optimizers can handle variables whose gradients are None.""" + with self.cached_session() as sess: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + # Create a string variable whose gradient will be None. + extra_variable = tf.Variable('foo', trainable=True, dtype=tf.string) + + dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0) + dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, + num_microbatches / 1e6) + + opt = cls( + dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + minimize_op = opt.minimize( + loss=self._loss(data0, var0), var_list=[var0, extra_variable]) + sess.run(minimize_op) + if __name__ == '__main__': tf.test.main()