From 81a11eb824e6d6ce0e1d07193dcb39f4bb595102 Mon Sep 17 00:00:00 2001 From: Michael Reneer Date: Thu, 27 Jan 2022 12:34:33 -0800 Subject: [PATCH] Remove the `mock` dependency from TensorFlow Privacy, this is now part of the Python standard library. PiperOrigin-RevId: 424681527 --- .../privacy/dp_query/restart_query_test.py | 9 +++++---- .../privacy/optimizers/dp_optimizer_test.py | 4 ++-- .../privacy/optimizers/dp_optimizer_vectorized_test.py | 5 +++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/restart_query_test.py b/tensorflow_privacy/privacy/dp_query/restart_query_test.py index 1ce303a..fe8a440 100644 --- a/tensorflow_privacy/privacy/dp_query/restart_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/restart_query_test.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for `restart_query`.""" -from absl.testing import parameterized -import mock +import unittest + +from absl.testing import parameterized import tensorflow as tf + from tensorflow_privacy.privacy.dp_query import restart_query from tensorflow_privacy.privacy.dp_query import tree_aggregation_query @@ -77,7 +78,7 @@ class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): # the `PeriodicTimeRestartIndicator` to unroll the mock test. return_time = tf.Variable( 1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize - with mock.patch.object( + with unittest.mock.patch.object( tf, 'timestamp', return_value=return_time) as mock_func: time_stamps = [ 1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py index 611464a..4c854d1 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py @@ -13,9 +13,9 @@ # limitations under the License. import os +import unittest from absl.testing import parameterized -import mock import numpy as np import tensorflow.compat.v1 as tf @@ -175,7 +175,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): # Test standard deviation is close to l2_norm_clip * noise_multiplier. self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5) - @mock.patch('absl.logging.warning') + @unittest.mock.patch('absl.logging.warning') def testComputeGradientsOverrideWarning(self, mock_logging): class SimpleOptimizer(tf.train.Optimizer): diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py index 6d326fd..977d3bd 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + from absl.testing import parameterized -import mock import numpy as np import tensorflow.compat.v1 as tf @@ -112,7 +113,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): # Test standard deviation is close to l2_norm_clip * noise_multiplier. self.assertNear(np.std(grads), 4.0 * 8.0, 0.5) - @mock.patch('absl.logging.warning') + @unittest.mock.patch('absl.logging.warning') def testComputeGradientsOverrideWarning(self, mock_logging): class SimpleOptimizer(tf.train.Optimizer):