Remove the mock dependency from TensorFlow Privacy, this is now part of the Python standard library.

PiperOrigin-RevId: 424681527
This commit is contained in:
Michael Reneer 2022-01-27 12:34:33 -08:00 committed by A. Unique TensorFlower
parent a749ce4e30
commit 81a11eb824
3 changed files with 10 additions and 8 deletions

View file

@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for `restart_query`."""
from absl.testing import parameterized
import mock
import unittest
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import restart_query from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
@ -77,7 +78,7 @@ class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
# the `PeriodicTimeRestartIndicator` to unroll the mock test. # the `PeriodicTimeRestartIndicator` to unroll the mock test.
return_time = tf.Variable( return_time = tf.Variable(
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize 1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
with mock.patch.object( with unittest.mock.patch.object(
tf, 'timestamp', return_value=return_time) as mock_func: tf, 'timestamp', return_value=return_time) as mock_func:
time_stamps = [ time_stamps = [
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False 1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False

View file

@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
import os import os
import unittest
from absl.testing import parameterized from absl.testing import parameterized
import mock
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
@ -175,7 +175,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
# Test standard deviation is close to l2_norm_clip * noise_multiplier. # Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5) self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5)
@mock.patch('absl.logging.warning') @unittest.mock.patch('absl.logging.warning')
def testComputeGradientsOverrideWarning(self, mock_logging): def testComputeGradientsOverrideWarning(self, mock_logging):
class SimpleOptimizer(tf.train.Optimizer): class SimpleOptimizer(tf.train.Optimizer):

View file

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
from absl.testing import parameterized from absl.testing import parameterized
import mock
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
@ -112,7 +113,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
# Test standard deviation is close to l2_norm_clip * noise_multiplier. # Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 4.0 * 8.0, 0.5) self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)
@mock.patch('absl.logging.warning') @unittest.mock.patch('absl.logging.warning')
def testComputeGradientsOverrideWarning(self, mock_logging): def testComputeGradientsOverrideWarning(self, mock_logging):
class SimpleOptimizer(tf.train.Optimizer): class SimpleOptimizer(tf.train.Optimizer):