Fix lint errors in dp_optimizer_test.

PiperOrigin-RevId: 424183036
This commit is contained in:
Michael Reneer 2022-01-25 14:36:28 -08:00 committed by A. Unique TensorFlower
parent 3a4c4400a6
commit 4b76e882bc

View file

@ -319,6 +319,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
) )
def testNoneGradients(self, cls, num_microbatches, expected_answer): def testNoneGradients(self, cls, num_microbatches, expected_answer):
"""Tests that optimizers can handle variables whose gradients are None.""" """Tests that optimizers can handle variables whose gradients are None."""
del expected_answer # Unused.
with self.cached_session() as sess: with self.cached_session() as sess:
var0 = tf.Variable([1.0, 2.0]) var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
@ -338,7 +340,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
loss=self._loss(data0, var0), var_list=[var0, extra_variable]) loss=self._loss(data0, var0), var_list=[var0, extra_variable])
sess.run(minimize_op) sess.run(minimize_op)
def _testWriteOutAndReload(self, optimizer_cls): def _test_write_out_and_reload(self, optimizer_cls):
optimizer = optimizer_cls( optimizer = optimizer_cls(
l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1) l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1)
@ -365,12 +367,12 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
def testWriteOutAndReloadAdam(self): def testWriteOutAndReloadAdam(self):
optimizer_class = dp_optimizer.make_gaussian_optimizer_class( optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
tf.keras.optimizers.Adam) tf.keras.optimizers.Adam)
self._testWriteOutAndReload(optimizer_class) self._test_write_out_and_reload(optimizer_class)
def testWriteOutAndReloadSGD(self): def testWriteOutAndReloadSGD(self):
optimizer_class = dp_optimizer.make_gaussian_optimizer_class( optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
tf.keras.optimizers.SGD) tf.keras.optimizers.SGD)
self._testWriteOutAndReload(optimizer_class) self._test_write_out_and_reload(optimizer_class)
if __name__ == '__main__': if __name__ == '__main__':