Fix lint errors in dp_optimizer_test
.
PiperOrigin-RevId: 424183036
This commit is contained in:
parent
3a4c4400a6
commit
4b76e882bc
1 changed files with 5 additions and 3 deletions
|
@ -319,6 +319,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
def testNoneGradients(self, cls, num_microbatches, expected_answer):
|
||||
"""Tests that optimizers can handle variables whose gradients are None."""
|
||||
del expected_answer # Unused.
|
||||
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||
|
@ -338,7 +340,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
loss=self._loss(data0, var0), var_list=[var0, extra_variable])
|
||||
sess.run(minimize_op)
|
||||
|
||||
def _testWriteOutAndReload(self, optimizer_cls):
|
||||
def _test_write_out_and_reload(self, optimizer_cls):
|
||||
optimizer = optimizer_cls(
|
||||
l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1)
|
||||
|
||||
|
@ -365,12 +367,12 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
def testWriteOutAndReloadAdam(self):
|
||||
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
||||
tf.keras.optimizers.Adam)
|
||||
self._testWriteOutAndReload(optimizer_class)
|
||||
self._test_write_out_and_reload(optimizer_class)
|
||||
|
||||
def testWriteOutAndReloadSGD(self):
|
||||
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
||||
tf.keras.optimizers.SGD)
|
||||
self._testWriteOutAndReload(optimizer_class)
|
||||
self._test_write_out_and_reload(optimizer_class)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in a new issue