From d6ad59226d64a5e47c46ea30d214d9c58ad1b198 Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Thu, 18 Aug 2022 17:37:31 -0700 Subject: [PATCH] Update tests for optimizer classes to TF 2. PiperOrigin-RevId: 468587323 --- tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py | 5 +++++ .../privacy/optimizers/dp_optimizer_vectorized_test.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py index 3b1c04f..6699fb5 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py @@ -26,6 +26,11 @@ from tensorflow_privacy.privacy.optimizers import dp_optimizer class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + @classmethod + def setUpClass(cls): + super(DPOptimizerTest, cls).setUpClass() + tf.compat.v1.disable_eager_execution() + def _loss(self, val0, val1): """Loss function that is minimized at the mean of the input points.""" return 0.5 * tf.reduce_sum( diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py index de3cc54..a1f81f6 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py @@ -27,6 +27,11 @@ from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import Vector class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + @classmethod + def setUpClass(cls): + super(DPOptimizerTest, cls).setUpClass() + tf.compat.v1.disable_eager_execution() + def _loss(self, val0, val1): """Loss function that is minimized at the mean of the input points.""" return 0.5 * tf.reduce_sum(