Fix Keras DP optimizer when num_microbatches is None.

The optimizer must not save TF tensors into class members: a tensor created during one tf.function trace leaks into later traces (for example, when a new batch size forces a retrace), where it belongs to a different graph and the code fails. Keep such tensors in local variables instead.
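
A minimal standalone sketch of the failure mode and the fix, using a hypothetical Scaler class (not part of this change; _denominator plays the role of _num_microbatches):

import tensorflow as tf

class Scaler:

  def __init__(self):
    self._denominator = None  # plays the role of self._num_microbatches

  @tf.function
  def bad_scale(self, x):
    # Bug: the symbolic tensor created during the first trace is cached
    # on the instance; a later retrace (e.g. a new batch size) then
    # refers to a tensor from a different graph and fails.
    if self._denominator is None:
      self._denominator = tf.shape(x)[0]
    return x / tf.cast(self._denominator, tf.float32)

  @tf.function
  def good_scale(self, x):
    # Fix: keep the tensor in a local so each trace builds its own.
    denominator = tf.shape(x)[0]
    return x / tf.cast(denominator, tf.float32)

s = Scaler()
s.bad_scale(tf.ones([4]))    # first call: traces and succeeds
# s.bad_scale(tf.ones([8]))  # retrace: raises an out-of-scope tensor error
s.good_scale(tf.ones([8]))   # works for any batch size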

PiperOrigin-RevId: 374976737
A. Unique TensorFlower 2021-05-20 16:46:30 -07:00
parent e5848656ed
commit a03374be6c
2 changed files with 23 additions and 18 deletions

tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py

@@ -88,21 +88,23 @@ def make_keras_optimizer_class(cls):
             tape.watch(var_list)

           loss = loss()
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)

           if callable(var_list):
             var_list = var_list()
       else:
         with tape:
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)

         var_list = tf.nest.flatten(var_list)
@@ -128,7 +130,7 @@ def make_keras_optimizer_class(cls):
         # Normalize by number of microbatches and return.
         return tf.truediv(noised_gradient,
-                          tf.cast(self._num_microbatches, tf.float32))
+                          tf.cast(num_microbatches, tf.float32))

       final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
                                               clipped_gradients)
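
A usage sketch for the fixed path (hyperparameters illustrative; DPKerasSGDOptimizer is the wrapper this factory produces). With num_microbatches=None, the microbatch count is now read from the batch at trace time instead of being cached on the instance:

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer

optimizer = DPKerasSGDOptimizer(
    l2_norm_clip=1.0,        # illustrative hyperparameters
    noise_multiplier=1.1,
    num_microbatches=None,   # the case this commit fixes
    learning_rate=0.1)

# The loss must stay per-example so the optimizer can reshape it into
# microbatches; with num_microbatches=None each example is its own
# microbatch.
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(optimizer=optimizer, loss=loss)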

tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py

@@ -101,21 +101,23 @@ def make_vectorized_keras_optimizer_class(cls):
             tape.watch(var_list)

           loss = loss()
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)

           if callable(var_list):
             var_list = var_list()
       else:
         with tape:
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)

         var_list = tf.nest.flatten(var_list)
@@ -138,7 +140,7 @@ def make_vectorized_keras_optimizer_class(cls):
         # Normalize by number of microbatches and return.
         return tf.truediv(noised_gradient,
-                          tf.cast(self._num_microbatches, tf.float32))
+                          tf.cast(num_microbatches, tf.float32))

       final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
                                               clipped_gradients)
@@ -152,11 +154,12 @@ def make_vectorized_keras_optimizer_class(cls):
       if self._global_state is None:
         self._global_state = self._dp_sum_query.initial_global_state()

-      batch_size = tf.shape(input=loss)[0]
       if self._num_microbatches is None:
-        self._num_microbatches = batch_size
+        num_microbatches = tf.shape(input=loss)[0]
+      else:
+        num_microbatches = self._num_microbatches
-      microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
+      microbatch_losses = tf.reshape(loss, [num_microbatches, -1])

      def process_microbatch(microbatch_loss):
        """Compute clipped grads for one microbatch."""
@@ -177,7 +180,7 @@ def make_vectorized_keras_optimizer_class(cls):
         noise = tf.random.normal(
             tf.shape(input=summed_grads), stddev=noise_stddev)
         noised_grads = summed_grads + noise
-        return noised_grads / tf.cast(self._num_microbatches, tf.float32)
+        return noised_grads / tf.cast(num_microbatches, tf.float32)

       final_grads = tf.nest.map_structure(reduce_noise_normalize_batch,
                                           clipped_grads)
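
One way to exercise the fix (a sketch, not a test from this change): run two steps with different batch sizes so tf.function retraces; before this commit the second step reused the stale tensor cached in self._num_microbatches. VectorizedDPKerasSGDOptimizer is the wrapper produced by this factory; the model and data are illustrative:

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasSGDOptimizer

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
optimizer = VectorizedDPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=None,  # each example becomes its own microbatch
    learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)  # keep per-example losses

@tf.function
def train_step(x, y):
  # With a callable loss, the DP optimizer creates its own tape and
  # reads the microbatch count from the batch inside the trace.
  optimizer.minimize(lambda: loss_fn(y, model(x)), model.trainable_variables)

train_step(tf.ones([4, 3]), tf.ones([4, 1]))  # traces with batch size 4
train_step(tf.ones([8, 3]), tf.ones([8, 1]))  # retraces with batch size 8; now works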