Fix Keras DP optimizer when num_microbatches == None.

The optimizer should not save TF tensors into class members; otherwise the code may not work in some cases under tf.function.

PiperOrigin-RevId: 374976737
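To see why this matters, here is a minimal sketch of the failure mode (illustrative code, not from this repo; `CachingModel` and `step` are hypothetical names). Under tf.function, `tf.shape(loss)[0]` is a symbolic tensor owned by the trace that created it, so stashing it on `self` leaks it across traces:

import tensorflow as tf


class CachingModel:
  """Hypothetical repro of the bug: caches a graph tensor on self."""

  def __init__(self):
    self._num_microbatches = None  # meant to be filled in lazily

  @tf.function
  def step(self, loss):
    if self._num_microbatches is None:
      # Inside tf.function this is a symbolic tensor belonging to the
      # current FuncGraph, not a Python int. Saving it on self leaks it
      # out of this trace.
      self._num_microbatches = tf.shape(loss)[0]
    # On a retrace (e.g. a different batch size), the attribute is already
    # set, so this line tries to reuse a tensor from the old graph, which
    # raises an out-of-scope error (or, if the first call ran eagerly,
    # silently freezes the microbatch count at the first batch size).
    return tf.reduce_sum(tf.reshape(loss, [self._num_microbatches, -1]))


# The committed fix uses the safe pattern: resolve into a local each call.
@tf.function
def step_fixed(loss):
  num_microbatches = tf.shape(loss)[0]  # local; never escapes the trace
  return tf.reduce_sum(tf.reshape(loss, [num_microbatches, -1]))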
This commit is contained in:
parent e5848656ed
commit a03374be6c

2 changed files with 23 additions and 18 deletions
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py

@@ -88,21 +88,23 @@ def make_keras_optimizer_class(cls):
             tape.watch(var_list)
 
           loss = loss()
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
           if callable(var_list):
             var_list = var_list()
       else:
         with tape:
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
       var_list = tf.nest.flatten(var_list)
 
@@ -128,7 +130,7 @@ def make_keras_optimizer_class(cls):
 
           # Normalize by number of microbatches and return.
           return tf.truediv(noised_gradient,
-                            tf.cast(self._num_microbatches, tf.float32))
+                            tf.cast(num_microbatches, tf.float32))
 
         final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
                                                 clipped_gradients)
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py

@@ -101,21 +101,23 @@ def make_vectorized_keras_optimizer_class(cls):
             tape.watch(var_list)
 
           loss = loss()
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
           if callable(var_list):
             var_list = var_list()
       else:
         with tape:
-          batch_size = tf.shape(input=loss)[0]
           if self._num_microbatches is None:
-            self._num_microbatches = batch_size
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
       var_list = tf.nest.flatten(var_list)
 
@@ -138,7 +140,7 @@ def make_vectorized_keras_optimizer_class(cls):
 
           # Normalize by number of microbatches and return.
           return tf.truediv(noised_gradient,
-                            tf.cast(self._num_microbatches, tf.float32))
+                            tf.cast(num_microbatches, tf.float32))
 
         final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
                                                 clipped_gradients)
@@ -152,11 +154,12 @@ def make_vectorized_keras_optimizer_class(cls):
       if self._global_state is None:
         self._global_state = self._dp_sum_query.initial_global_state()
 
-      batch_size = tf.shape(input=loss)[0]
       if self._num_microbatches is None:
-        self._num_microbatches = batch_size
+        num_microbatches = tf.shape(input=loss)[0]
+      else:
+        num_microbatches = self._num_microbatches
 
-      microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
+      microbatch_losses = tf.reshape(loss, [num_microbatches, -1])
 
       def process_microbatch(microbatch_loss):
         """Compute clipped grads for one microbatch."""
@@ -177,7 +180,7 @@ def make_vectorized_keras_optimizer_class(cls):
         noise = tf.random.normal(
             tf.shape(input=summed_grads), stddev=noise_stddev)
         noised_grads = summed_grads + noise
-        return noised_grads / tf.cast(self._num_microbatches, tf.float32)
+        return noised_grads / tf.cast(num_microbatches, tf.float32)
 
       final_grads = tf.nest.map_structure(reduce_noise_normalize_batch,
                                           clipped_grads)
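For reference, a minimal usage sketch of the fixed path (the hyperparameter values are illustrative, not recommendations): with num_microbatches=None, the optimizer now derives the microbatch count from each incoming batch instead of caching the first batch's size.

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
    l2_norm_clip=1.0,       # illustrative clipping norm
    noise_multiplier=1.1,   # illustrative noise multiplier
    num_microbatches=None,  # now resolved per batch at call time
    learning_rate=0.1)

# DP optimizers need the per-example loss vector, so reduction must be NONE.
loss = tf.keras.losses.CategoricalCrossentropy(
    reduction=tf.losses.Reduction.NONE)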