Update tests for optimizer classes to TF 2.

PiperOrigin-RevId: 468587323
This commit is contained in:
Steve Chien 2022-08-18 17:37:31 -07:00 committed by A. Unique TensorFlower
parent 5dd11fcdd6
commit d6ad59226d
2 changed files with 10 additions and 0 deletions

View file

@ -26,6 +26,11 @@ from tensorflow_privacy.privacy.optimizers import dp_optimizer
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
super(DPOptimizerTest, cls).setUpClass()
tf.compat.v1.disable_eager_execution()
def _loss(self, val0, val1): def _loss(self, val0, val1):
"""Loss function that is minimized at the mean of the input points.""" """Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum( return 0.5 * tf.reduce_sum(

View file

@ -27,6 +27,11 @@ from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import Vector
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
super(DPOptimizerTest, cls).setUpClass()
tf.compat.v1.disable_eager_execution()
def _loss(self, val0, val1): def _loss(self, val0, val1):
"""Loss function that is minimized at the mean of the input points.""" """Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum( return 0.5 * tf.reduce_sum(