Update keras optimizers (both traditional and vectorized) to handle case of num_microbatches=None.
PiperOrigin-RevId: 369497296
This commit is contained in:
parent
41530f4426
commit
755ed26671
3 changed files with 30 additions and 10 deletions
|
@ -87,8 +87,10 @@ def make_keras_optimizer_class(cls):
|
|||
if not callable(var_list):
|
||||
tape.watch(var_list)
|
||||
|
||||
if callable(loss):
|
||||
loss = loss()
|
||||
batch_size = tf.shape(input=loss)[0]
|
||||
if self._num_microbatches is None:
|
||||
self._num_microbatches = batch_size
|
||||
microbatch_losses = tf.reduce_mean(
|
||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||
|
||||
|
@ -96,6 +98,9 @@ def make_keras_optimizer_class(cls):
|
|||
var_list = var_list()
|
||||
else:
|
||||
with tape:
|
||||
batch_size = tf.shape(input=loss)[0]
|
||||
if self._num_microbatches is None:
|
||||
self._num_microbatches = batch_size
|
||||
microbatch_losses = tf.reduce_mean(
|
||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||
|
||||
|
@ -122,7 +127,8 @@ def make_keras_optimizer_class(cls):
|
|||
noised_gradient = tf.add(summed_gradient, noise)
|
||||
|
||||
# Normalize by number of microbatches and return.
|
||||
return tf.truediv(noised_gradient, self._num_microbatches)
|
||||
return tf.truediv(noised_gradient,
|
||||
tf.cast(self._num_microbatches, tf.float32))
|
||||
|
||||
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
|
||||
clipped_gradients)
|
||||
|
|
|
@ -51,6 +51,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('DPAdagradVectorized 4',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
|
||||
[-2.5, -2.5], [-0.5]),
|
||||
('DPAdagradVectorized None',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
|
||||
[-2.5, -2.5], [-0.5]),
|
||||
)
|
||||
def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0,
|
||||
expected_grad1):
|
||||
|
@ -89,6 +92,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('DPAdagradVectorized 4',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
|
||||
[-2.5, -2.5], [-0.5]),
|
||||
('DPAdagradVectorized None',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
|
||||
[-2.5, -2.5], [-0.5]),
|
||||
)
|
||||
def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0,
|
||||
expected_grad1):
|
||||
|
@ -244,6 +250,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
|
||||
('DPGradientDescentVectorized 4',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
|
||||
('DPGradientDescentVectorized None',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
|
||||
)
|
||||
def testBaseline(self, cls, num_microbatches):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
|
|
@ -100,8 +100,10 @@ def make_vectorized_keras_optimizer_class(cls):
|
|||
if not callable(var_list):
|
||||
tape.watch(var_list)
|
||||
|
||||
if callable(loss):
|
||||
loss = loss()
|
||||
batch_size = tf.shape(input=loss)[0]
|
||||
if self._num_microbatches is None:
|
||||
self._num_microbatches = batch_size
|
||||
microbatch_losses = tf.reduce_mean(
|
||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||
|
||||
|
@ -109,6 +111,9 @@ def make_vectorized_keras_optimizer_class(cls):
|
|||
var_list = var_list()
|
||||
else:
|
||||
with tape:
|
||||
batch_size = tf.shape(input=loss)[0]
|
||||
if self._num_microbatches is None:
|
||||
self._num_microbatches = batch_size
|
||||
microbatch_losses = tf.reduce_mean(
|
||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||
|
||||
|
@ -132,7 +137,8 @@ def make_vectorized_keras_optimizer_class(cls):
|
|||
noised_gradient = tf.add(summed_gradient, noise)
|
||||
|
||||
# Normalize by number of microbatches and return.
|
||||
return tf.truediv(noised_gradient, self._num_microbatches)
|
||||
return tf.truediv(noised_gradient,
|
||||
tf.cast(self._num_microbatches, tf.float32))
|
||||
|
||||
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
|
||||
clipped_gradients)
|
||||
|
|
Loading…
Reference in a new issue