Update keras optimizers (both traditional and vectorized) to handle case of num_microbatches=None.

PiperOrigin-RevId: 369497296
This commit is contained in:
Steve Chien 2021-04-20 12:35:00 -07:00 committed by A. Unique TensorFlower
parent 41530f4426
commit 755ed26671
3 changed files with 30 additions and 10 deletions

View file

@ -87,8 +87,10 @@ def make_keras_optimizer_class(cls):
if not callable(var_list):
tape.watch(var_list)
if callable(loss):
loss = loss()
batch_size = tf.shape(input=loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
@ -96,6 +98,9 @@ def make_keras_optimizer_class(cls):
var_list = var_list()
else:
with tape:
batch_size = tf.shape(input=loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
@ -122,7 +127,8 @@ def make_keras_optimizer_class(cls):
noised_gradient = tf.add(summed_gradient, noise)
# Normalize by number of microbatches and return.
return tf.truediv(noised_gradient, self._num_microbatches)
return tf.truediv(noised_gradient,
tf.cast(self._num_microbatches, tf.float32))
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
clipped_gradients)

View file

@ -51,6 +51,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
('DPAdagradVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
[-2.5, -2.5], [-0.5]),
('DPAdagradVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
[-2.5, -2.5], [-0.5]),
)
def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0,
expected_grad1):
@ -89,6 +92,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
('DPAdagradVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
[-2.5, -2.5], [-0.5]),
('DPAdagradVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
[-2.5, -2.5], [-0.5]),
)
def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0,
expected_grad1):
@ -244,6 +250,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
('DPGradientDescentVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
('DPGradientDescentVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
)
def testBaseline(self, cls, num_microbatches):
"""Tests that DP optimizers work with tf.estimator."""

View file

@ -100,8 +100,10 @@ def make_vectorized_keras_optimizer_class(cls):
if not callable(var_list):
tape.watch(var_list)
if callable(loss):
loss = loss()
batch_size = tf.shape(input=loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
@ -109,6 +111,9 @@ def make_vectorized_keras_optimizer_class(cls):
var_list = var_list()
else:
with tape:
batch_size = tf.shape(input=loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
@ -132,7 +137,8 @@ def make_vectorized_keras_optimizer_class(cls):
noised_gradient = tf.add(summed_gradient, noise)
# Normalize by number of microbatches and return.
return tf.truediv(noised_gradient, self._num_microbatches)
return tf.truediv(noised_gradient,
tf.cast(self._num_microbatches, tf.float32))
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
clipped_gradients)