Remove calls to _dp_sum_query.set_batch_size in dp_optimizer.py, as no method with that name exists for objects of class QueryWithLedger.

PiperOrigin-RevId: 259858031
This commit is contained in:
A. Unique TensorFlower 2019-07-24 18:08:03 -07:00
parent 28fd879864
commit 5cd2439401

View file

@ -93,8 +93,6 @@ def make_optimizer_class(cls):
vector_loss = loss() vector_loss = loss()
if self._num_microbatches is None: if self._num_microbatches is None:
self._num_microbatches = tf.shape(vector_loss)[0] self._num_microbatches = tf.shape(vector_loss)[0]
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
self._dp_sum_query.set_batch_size(self._num_microbatches)
sample_state = self._dp_sum_query.initial_sample_state(var_list) sample_state = self._dp_sum_query.initial_sample_state(var_list)
microbatches_losses = tf.reshape(vector_loss, microbatches_losses = tf.reshape(vector_loss,
[self._num_microbatches, -1]) [self._num_microbatches, -1])
@ -135,8 +133,6 @@ def make_optimizer_class(cls):
# sampling from the dataset without replacement. # sampling from the dataset without replacement.
if self._num_microbatches is None: if self._num_microbatches is None:
self._num_microbatches = tf.shape(loss)[0] self._num_microbatches = tf.shape(loss)[0]
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
self._dp_sum_query.set_batch_size(self._num_microbatches)
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1]) microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
sample_params = ( sample_params = (