diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index e70086f..83a3f4d 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -93,8 +93,6 @@ def make_optimizer_class(cls): vector_loss = loss() if self._num_microbatches is None: self._num_microbatches = tf.shape(vector_loss)[0] - if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger): - self._dp_sum_query.set_batch_size(self._num_microbatches) sample_state = self._dp_sum_query.initial_sample_state(var_list) microbatches_losses = tf.reshape(vector_loss, [self._num_microbatches, -1]) @@ -135,8 +133,6 @@ def make_optimizer_class(cls): # sampling from the dataset without replacement. if self._num_microbatches is None: self._num_microbatches = tf.shape(loss)[0] - if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger): - self._dp_sum_query.set_batch_size(self._num_microbatches) microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1]) sample_params = (