diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py index eac2916..1f69f1a 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py @@ -88,21 +88,23 @@ def make_keras_optimizer_class(cls): tape.watch(var_list) loss = loss() - batch_size = tf.shape(input=loss)[0] if self._num_microbatches is None: - self._num_microbatches = batch_size + num_microbatches = tf.shape(input=loss)[0] + else: + num_microbatches = self._num_microbatches microbatch_losses = tf.reduce_mean( - tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + tf.reshape(loss, [num_microbatches, -1]), axis=1) if callable(var_list): var_list = var_list() else: with tape: - batch_size = tf.shape(input=loss)[0] if self._num_microbatches is None: - self._num_microbatches = batch_size + num_microbatches = tf.shape(input=loss)[0] + else: + num_microbatches = self._num_microbatches microbatch_losses = tf.reduce_mean( - tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + tf.reshape(loss, [num_microbatches, -1]), axis=1) var_list = tf.nest.flatten(var_list) @@ -128,7 +130,7 @@ def make_keras_optimizer_class(cls): # Normalize by number of microbatches and return. return tf.truediv(noised_gradient, - tf.cast(self._num_microbatches, tf.float32)) + tf.cast(num_microbatches, tf.float32)) final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, clipped_gradients) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py index e1d3fc6..ce375b8 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py @@ -101,21 +101,23 @@ def make_vectorized_keras_optimizer_class(cls): tape.watch(var_list) loss = loss() - batch_size = tf.shape(input=loss)[0] if self._num_microbatches is None: - self._num_microbatches = batch_size + num_microbatches = tf.shape(input=loss)[0] + else: + num_microbatches = self._num_microbatches microbatch_losses = tf.reduce_mean( - tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + tf.reshape(loss, [num_microbatches, -1]), axis=1) if callable(var_list): var_list = var_list() else: with tape: - batch_size = tf.shape(input=loss)[0] if self._num_microbatches is None: - self._num_microbatches = batch_size + num_microbatches = tf.shape(input=loss)[0] + else: + num_microbatches = self._num_microbatches microbatch_losses = tf.reduce_mean( - tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + tf.reshape(loss, [num_microbatches, -1]), axis=1) var_list = tf.nest.flatten(var_list) @@ -138,7 +140,7 @@ def make_vectorized_keras_optimizer_class(cls): # Normalize by number of microbatches and return. return tf.truediv(noised_gradient, - tf.cast(self._num_microbatches, tf.float32)) + tf.cast(num_microbatches, tf.float32)) final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, clipped_gradients) @@ -152,11 +154,12 @@ def make_vectorized_keras_optimizer_class(cls): if self._global_state is None: self._global_state = self._dp_sum_query.initial_global_state() - batch_size = tf.shape(input=loss)[0] if self._num_microbatches is None: - self._num_microbatches = batch_size + num_microbatches = tf.shape(input=loss)[0] + else: + num_microbatches = self._num_microbatches - microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1]) + microbatch_losses = tf.reshape(loss, [num_microbatches, -1]) def process_microbatch(microbatch_loss): """Compute clipped grads for one microbatch.""" @@ -177,7 +180,7 @@ def make_vectorized_keras_optimizer_class(cls): noise = tf.random.normal( tf.shape(input=summed_grads), stddev=noise_stddev) noised_grads = summed_grads + noise - return noised_grads / tf.cast(self._num_microbatches, tf.float32) + return noised_grads / tf.cast(num_microbatches, tf.float32) final_grads = tf.nest.map_structure(reduce_noise_normalize_batch, clipped_grads)