Update Keras optimizers (both traditional and vectorized) to handle the case of num_microbatches=None.

PiperOrigin-RevId: 369497296
This commit is contained in:
Steve Chien 2021-04-20 12:35:00 -07:00 committed by A. Unique TensorFlower
parent 41530f4426
commit 755ed26671
3 changed files with 30 additions and 10 deletions

View file

@ -87,15 +87,20 @@ def make_keras_optimizer_class(cls):
if not callable(var_list): if not callable(var_list):
tape.watch(var_list) tape.watch(var_list)
if callable(loss): loss = loss()
loss = loss() batch_size = tf.shape(input=loss)[0]
microbatch_losses = tf.reduce_mean( if self._num_microbatches is None:
tf.reshape(loss, [self._num_microbatches, -1]), axis=1) self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
if callable(var_list): if callable(var_list):
var_list = var_list() var_list = var_list()
else: else:
with tape: with tape:
batch_size = tf.shape(input=loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean( microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1) tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
@ -122,7 +127,8 @@ def make_keras_optimizer_class(cls):
noised_gradient = tf.add(summed_gradient, noise) noised_gradient = tf.add(summed_gradient, noise)
# Normalize by number of microbatches and return. # Normalize by number of microbatches and return.
return tf.truediv(noised_gradient, self._num_microbatches) return tf.truediv(noised_gradient,
tf.cast(self._num_microbatches, tf.float32))
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
clipped_gradients) clipped_gradients)

View file

@ -51,6 +51,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
('DPAdagradVectorized 4', ('DPAdagradVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
[-2.5, -2.5], [-0.5]), [-2.5, -2.5], [-0.5]),
('DPAdagradVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
[-2.5, -2.5], [-0.5]),
) )
def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0, def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0,
expected_grad1): expected_grad1):
@ -89,6 +92,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
('DPAdagradVectorized 4', ('DPAdagradVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
[-2.5, -2.5], [-0.5]), [-2.5, -2.5], [-0.5]),
('DPAdagradVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
[-2.5, -2.5], [-0.5]),
) )
def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0, def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0,
expected_grad1): expected_grad1):
@ -244,6 +250,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
('DPGradientDescentVectorized 4', ('DPGradientDescentVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
('DPGradientDescentVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
) )
def testBaseline(self, cls, num_microbatches): def testBaseline(self, cls, num_microbatches):
"""Tests that DP optimizers work with tf.estimator.""" """Tests that DP optimizers work with tf.estimator."""

View file

@ -100,15 +100,20 @@ def make_vectorized_keras_optimizer_class(cls):
if not callable(var_list): if not callable(var_list):
tape.watch(var_list) tape.watch(var_list)
if callable(loss): loss = loss()
loss = loss() batch_size = tf.shape(input=loss)[0]
microbatch_losses = tf.reduce_mean( if self._num_microbatches is None:
tf.reshape(loss, [self._num_microbatches, -1]), axis=1) self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
if callable(var_list): if callable(var_list):
var_list = var_list() var_list = var_list()
else: else:
with tape: with tape:
batch_size = tf.shape(input=loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean( microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1) tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
@ -132,7 +137,8 @@ def make_vectorized_keras_optimizer_class(cls):
noised_gradient = tf.add(summed_gradient, noise) noised_gradient = tf.add(summed_gradient, noise)
# Normalize by number of microbatches and return. # Normalize by number of microbatches and return.
return tf.truediv(noised_gradient, self._num_microbatches) return tf.truediv(noised_gradient,
tf.cast(self._num_microbatches, tf.float32))
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
clipped_gradients) clipped_gradients)