Update tests for optimizer classes to TF 2.
PiperOrigin-RevId: 468587323
This commit is contained in:
parent
5dd11fcdd6
commit
d6ad59226d
2 changed files with 10 additions and 0 deletions
|
@ -26,6 +26,11 @@ from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
|||
|
||||
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(DPOptimizerTest, cls).setUpClass()
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
|
||||
def _loss(self, val0, val1):
|
||||
"""Loss function that is minimized at the mean of the input points."""
|
||||
return 0.5 * tf.reduce_sum(
|
||||
|
|
|
@ -27,6 +27,11 @@ from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import Vector
|
|||
|
||||
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(DPOptimizerTest, cls).setUpClass()
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
|
||||
def _loss(self, val0, val1):
|
||||
"""Loss function that is minimized at the mean of the input points."""
|
||||
return 0.5 * tf.reduce_sum(
|
||||
|
|
Loading…
Reference in a new issue