diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py index 3b1c04f..6699fb5 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py @@ -26,6 +26,11 @@ from tensorflow_privacy.privacy.optimizers import dp_optimizer class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + @classmethod + def setUpClass(cls): + super(DPOptimizerTest, cls).setUpClass() + tf.compat.v1.disable_eager_execution() + def _loss(self, val0, val1): """Loss function that is minimized at the mean of the input points.""" return 0.5 * tf.reduce_sum( diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py index de3cc54..a1f81f6 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py @@ -27,6 +27,11 @@ from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import Vector class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + @classmethod + def setUpClass(cls): + super(DPOptimizerTest, cls).setUpClass() + tf.compat.v1.disable_eager_execution() + def _loss(self, val0, val1): """Loss function that is minimized at the mean of the input points.""" return 0.5 * tf.reduce_sum(