diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 268e17e..a0f7ebf 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -111,7 +111,10 @@ def make_optimizer_class(cls): tf.reduce_mean(tf.gather(microbatches_losses, [i])), var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, grad_loss)) - grads_list = list(grads) + grads_list = [ + g if g is not None else tf.zeros_like(v) + for (g, v) in zip(list(grads), var_list) + ] sample_state = self._dp_average_query.accumulate_record( sample_params, sample_state, grads_list) return sample_state