forked from 626_privacy/tensorflow_privacy
Fix Keras DP optimizer when num_microbatches == None.

The optimizer should not save TF tensors into class members; otherwise the code may fail in some cases under tf.function.

PiperOrigin-RevId: 374976737
This commit is contained in:
  parent e5848656ed
  commit a03374be6c

2 changed files with 23 additions and 18 deletions
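Why this matters: under tf.function, assigning a tensor to a Python attribute during tracing leaks a graph tensor into Python state. A minimal, self-contained sketch of the failure mode (hypothetical class, not code from this repository):

import tensorflow as tf

class BrokenCache:
  """Sketch of the bug: caches a batch-dependent tensor on `self`."""

  def __init__(self):
    self._num_microbatches = None

  @tf.function
  def microbatch_losses(self, loss):
    if self._num_microbatches is None:
      # Runs only while tracing; stores a symbolic tensor on `self`.
      self._num_microbatches = tf.shape(input=loss)[0]
    # A retrace (e.g. a new batch size) reaches this line holding a tensor
    # that belongs to the previous graph, and fails instead of recomputing.
    return tf.reduce_mean(
        tf.reshape(loss, [self._num_microbatches, -1]), axis=1)

Calling microbatch_losses with one batch size and then another forces a retrace, and the second trace cannot use the tensor cached by the first. The fix below keeps the count in a local variable so each trace recomputes it.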
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py

@@ -88,21 +88,23 @@ def make_keras_optimizer_class(cls):
             tape.watch(var_list)
 
           loss = loss()
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
           if callable(var_list):
             var_list = var_list()
       else:
         with tape:
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
       var_list = tf.nest.flatten(var_list)
 
@@ -128,7 +130,7 @@ def make_keras_optimizer_class(cls):
 
         # Normalize by number of microbatches and return.
         return tf.truediv(noised_gradient,
-                          tf.cast(self._num_microbatches, tf.float32))
+                          tf.cast(num_microbatches, tf.float32))
 
       final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
                                               clipped_gradients)
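For orientation, a usage sketch (illustrative, not part of this commit): after the fix, num_microbatches=None means the count is derived from each batch at run time, i.e. one microbatch per example.

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=None,  # resolved per batch after this fix
    learning_rate=0.1)

# DP optimizers need per-example (unreduced) losses so they can split the
# batch into microbatches before clipping.
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(optimizer=optimizer, loss=loss)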
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py

@@ -101,21 +101,23 @@ def make_vectorized_keras_optimizer_class(cls):
             tape.watch(var_list)
 
           loss = loss()
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
           if callable(var_list):
             var_list = var_list()
       else:
         with tape:
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
       var_list = tf.nest.flatten(var_list)
 
@@ -138,7 +140,7 @@ def make_vectorized_keras_optimizer_class(cls):
 
         # Normalize by number of microbatches and return.
        return tf.truediv(noised_gradient,
-                          tf.cast(self._num_microbatches, tf.float32))
+                          tf.cast(num_microbatches, tf.float32))
 
       final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
                                               clipped_gradients)
@@ -152,11 +154,12 @@ def make_vectorized_keras_optimizer_class(cls):
       if self._global_state is None:
         self._global_state = self._dp_sum_query.initial_global_state()
 
-      batch_size = tf.shape(input=loss)[0]
       if self._num_microbatches is None:
-        self._num_microbatches = batch_size
+        num_microbatches = tf.shape(input=loss)[0]
+      else:
+        num_microbatches = self._num_microbatches
 
-      microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
+      microbatch_losses = tf.reshape(loss, [num_microbatches, -1])
 
       def process_microbatch(microbatch_loss):
         """Compute clipped grads for one microbatch."""
@@ -177,7 +180,7 @@ def make_vectorized_keras_optimizer_class(cls):
         noise = tf.random.normal(
             tf.shape(input=summed_grads), stddev=noise_stddev)
         noised_grads = summed_grads + noise
-        return noised_grads / tf.cast(self._num_microbatches, tf.float32)
+        return noised_grads / tf.cast(num_microbatches, tf.float32)
 
       final_grads = tf.nest.map_structure(reduce_noise_normalize_batch,
                                           clipped_grads)
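The vectorized wrapper follows the same contract; a short sketch (class and module names from the library, values illustrative):

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_vectorized

opt = dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=None,
    learning_rate=0.25)

var = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

def loss_fn():
  # One loss per example, as the DP optimizers require.
  return tf.reduce_sum(var * var, axis=1)

# With the fix, the batch size is read inside each step rather than being
# frozen into a class member on first use.
opt.minimize(loss_fn, var_list=[var])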