Fix lint errors in dp_optimizer_test
.
PiperOrigin-RevId: 424183036
This commit is contained in:
parent
3a4c4400a6
commit
4b76e882bc
1 changed files with 5 additions and 3 deletions
|
@ -319,6 +319,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
)
|
)
|
||||||
def testNoneGradients(self, cls, num_microbatches, expected_answer):
|
def testNoneGradients(self, cls, num_microbatches, expected_answer):
|
||||||
"""Tests that optimizers can handle variables whose gradients are None."""
|
"""Tests that optimizers can handle variables whose gradients are None."""
|
||||||
|
del expected_answer # Unused.
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
var0 = tf.Variable([1.0, 2.0])
|
var0 = tf.Variable([1.0, 2.0])
|
||||||
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||||
|
@ -338,7 +340,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
loss=self._loss(data0, var0), var_list=[var0, extra_variable])
|
loss=self._loss(data0, var0), var_list=[var0, extra_variable])
|
||||||
sess.run(minimize_op)
|
sess.run(minimize_op)
|
||||||
|
|
||||||
def _testWriteOutAndReload(self, optimizer_cls):
|
def _test_write_out_and_reload(self, optimizer_cls):
|
||||||
optimizer = optimizer_cls(
|
optimizer = optimizer_cls(
|
||||||
l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1)
|
l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1)
|
||||||
|
|
||||||
|
@ -365,12 +367,12 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
def testWriteOutAndReloadAdam(self):
|
def testWriteOutAndReloadAdam(self):
|
||||||
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
||||||
tf.keras.optimizers.Adam)
|
tf.keras.optimizers.Adam)
|
||||||
self._testWriteOutAndReload(optimizer_class)
|
self._test_write_out_and_reload(optimizer_class)
|
||||||
|
|
||||||
def testWriteOutAndReloadSGD(self):
|
def testWriteOutAndReloadSGD(self):
|
||||||
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
||||||
tf.keras.optimizers.SGD)
|
tf.keras.optimizers.SGD)
|
||||||
self._testWriteOutAndReload(optimizer_class)
|
self._test_write_out_and_reload(optimizer_class)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue