Update keras optimizers (both traditional and vectorized) to handle case of num_microbatches=None.
PiperOrigin-RevId: 369497296
This commit is contained in:
parent
41530f4426
commit
755ed26671
3 changed files with 30 additions and 10 deletions
|
@ -87,15 +87,20 @@ def make_keras_optimizer_class(cls):
|
||||||
if not callable(var_list):
|
if not callable(var_list):
|
||||||
tape.watch(var_list)
|
tape.watch(var_list)
|
||||||
|
|
||||||
if callable(loss):
|
loss = loss()
|
||||||
loss = loss()
|
batch_size = tf.shape(input=loss)[0]
|
||||||
microbatch_losses = tf.reduce_mean(
|
if self._num_microbatches is None:
|
||||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
self._num_microbatches = batch_size
|
||||||
|
microbatch_losses = tf.reduce_mean(
|
||||||
|
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||||
|
|
||||||
if callable(var_list):
|
if callable(var_list):
|
||||||
var_list = var_list()
|
var_list = var_list()
|
||||||
else:
|
else:
|
||||||
with tape:
|
with tape:
|
||||||
|
batch_size = tf.shape(input=loss)[0]
|
||||||
|
if self._num_microbatches is None:
|
||||||
|
self._num_microbatches = batch_size
|
||||||
microbatch_losses = tf.reduce_mean(
|
microbatch_losses = tf.reduce_mean(
|
||||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||||
|
|
||||||
|
@ -122,7 +127,8 @@ def make_keras_optimizer_class(cls):
|
||||||
noised_gradient = tf.add(summed_gradient, noise)
|
noised_gradient = tf.add(summed_gradient, noise)
|
||||||
|
|
||||||
# Normalize by number of microbatches and return.
|
# Normalize by number of microbatches and return.
|
||||||
return tf.truediv(noised_gradient, self._num_microbatches)
|
return tf.truediv(noised_gradient,
|
||||||
|
tf.cast(self._num_microbatches, tf.float32))
|
||||||
|
|
||||||
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
|
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
|
||||||
clipped_gradients)
|
clipped_gradients)
|
||||||
|
|
|
@ -51,6 +51,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
('DPAdagradVectorized 4',
|
('DPAdagradVectorized 4',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
|
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
|
||||||
[-2.5, -2.5], [-0.5]),
|
[-2.5, -2.5], [-0.5]),
|
||||||
|
('DPAdagradVectorized None',
|
||||||
|
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
|
||||||
|
[-2.5, -2.5], [-0.5]),
|
||||||
)
|
)
|
||||||
def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0,
|
def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0,
|
||||||
expected_grad1):
|
expected_grad1):
|
||||||
|
@ -89,6 +92,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
('DPAdagradVectorized 4',
|
('DPAdagradVectorized 4',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
|
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
|
||||||
[-2.5, -2.5], [-0.5]),
|
[-2.5, -2.5], [-0.5]),
|
||||||
|
('DPAdagradVectorized None',
|
||||||
|
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
|
||||||
|
[-2.5, -2.5], [-0.5]),
|
||||||
)
|
)
|
||||||
def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0,
|
def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0,
|
||||||
expected_grad1):
|
expected_grad1):
|
||||||
|
@ -244,6 +250,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
|
||||||
('DPGradientDescentVectorized 4',
|
('DPGradientDescentVectorized 4',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
|
||||||
|
('DPGradientDescentVectorized None',
|
||||||
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
|
||||||
)
|
)
|
||||||
def testBaseline(self, cls, num_microbatches):
|
def testBaseline(self, cls, num_microbatches):
|
||||||
"""Tests that DP optimizers work with tf.estimator."""
|
"""Tests that DP optimizers work with tf.estimator."""
|
||||||
|
|
|
@ -100,15 +100,20 @@ def make_vectorized_keras_optimizer_class(cls):
|
||||||
if not callable(var_list):
|
if not callable(var_list):
|
||||||
tape.watch(var_list)
|
tape.watch(var_list)
|
||||||
|
|
||||||
if callable(loss):
|
loss = loss()
|
||||||
loss = loss()
|
batch_size = tf.shape(input=loss)[0]
|
||||||
microbatch_losses = tf.reduce_mean(
|
if self._num_microbatches is None:
|
||||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
self._num_microbatches = batch_size
|
||||||
|
microbatch_losses = tf.reduce_mean(
|
||||||
|
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||||
|
|
||||||
if callable(var_list):
|
if callable(var_list):
|
||||||
var_list = var_list()
|
var_list = var_list()
|
||||||
else:
|
else:
|
||||||
with tape:
|
with tape:
|
||||||
|
batch_size = tf.shape(input=loss)[0]
|
||||||
|
if self._num_microbatches is None:
|
||||||
|
self._num_microbatches = batch_size
|
||||||
microbatch_losses = tf.reduce_mean(
|
microbatch_losses = tf.reduce_mean(
|
||||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||||
|
|
||||||
|
@ -132,7 +137,8 @@ def make_vectorized_keras_optimizer_class(cls):
|
||||||
noised_gradient = tf.add(summed_gradient, noise)
|
noised_gradient = tf.add(summed_gradient, noise)
|
||||||
|
|
||||||
# Normalize by number of microbatches and return.
|
# Normalize by number of microbatches and return.
|
||||||
return tf.truediv(noised_gradient, self._num_microbatches)
|
return tf.truediv(noised_gradient,
|
||||||
|
tf.cast(self._num_microbatches, tf.float32))
|
||||||
|
|
||||||
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
|
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
|
||||||
clipped_gradients)
|
clipped_gradients)
|
||||||
|
|
Loading…
Reference in a new issue