Remove calls to _dp_sum_query.set_batch_size in dp_optimizer.py, as no method with that name exists for objects of class QueryWithLedger.
PiperOrigin-RevId: 259858031
This commit is contained in:
parent
28fd879864
commit
5cd2439401
1 changed files with 0 additions and 4 deletions
|
@ -93,8 +93,6 @@ def make_optimizer_class(cls):
|
||||||
vector_loss = loss()
|
vector_loss = loss()
|
||||||
if self._num_microbatches is None:
|
if self._num_microbatches is None:
|
||||||
self._num_microbatches = tf.shape(vector_loss)[0]
|
self._num_microbatches = tf.shape(vector_loss)[0]
|
||||||
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
|
|
||||||
self._dp_sum_query.set_batch_size(self._num_microbatches)
|
|
||||||
sample_state = self._dp_sum_query.initial_sample_state(var_list)
|
sample_state = self._dp_sum_query.initial_sample_state(var_list)
|
||||||
microbatches_losses = tf.reshape(vector_loss,
|
microbatches_losses = tf.reshape(vector_loss,
|
||||||
[self._num_microbatches, -1])
|
[self._num_microbatches, -1])
|
||||||
|
@ -135,8 +133,6 @@ def make_optimizer_class(cls):
|
||||||
# sampling from the dataset without replacement.
|
# sampling from the dataset without replacement.
|
||||||
if self._num_microbatches is None:
|
if self._num_microbatches is None:
|
||||||
self._num_microbatches = tf.shape(loss)[0]
|
self._num_microbatches = tf.shape(loss)[0]
|
||||||
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
|
|
||||||
self._dp_sum_query.set_batch_size(self._num_microbatches)
|
|
||||||
|
|
||||||
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
|
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
|
||||||
sample_params = (
|
sample_params = (
|
||||||
|
|
Loading…
Reference in a new issue