Remove the mock
dependency from TensorFlow Privacy, this is now part of the Python standard library.
PiperOrigin-RevId: 424681527
This commit is contained in:
parent
a749ce4e30
commit
81a11eb824
3 changed files with 10 additions and 8 deletions
|
@ -11,11 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for `restart_query`."""
|
||||
from absl.testing import parameterized
|
||||
import mock
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||
|
||||
|
@ -77,7 +78,7 @@ class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
# the `PeriodicTimeRestartIndicator` to unroll the mock test.
|
||||
return_time = tf.Variable(
|
||||
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
|
||||
with mock.patch.object(
|
||||
with unittest.mock.patch.object(
|
||||
tf, 'timestamp', return_value=return_time) as mock_func:
|
||||
time_stamps = [
|
||||
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False
|
||||
|
|
|
@ -13,9 +13,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
import mock
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
@ -175,7 +175,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||
self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5)
|
||||
|
||||
@mock.patch('absl.logging.warning')
|
||||
@unittest.mock.patch('absl.logging.warning')
|
||||
def testComputeGradientsOverrideWarning(self, mock_logging):
|
||||
|
||||
class SimpleOptimizer(tf.train.Optimizer):
|
||||
|
|
|
@ -12,8 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
import mock
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
@ -112,7 +113,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||
self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)
|
||||
|
||||
@mock.patch('absl.logging.warning')
|
||||
@unittest.mock.patch('absl.logging.warning')
|
||||
def testComputeGradientsOverrideWarning(self, mock_logging):
|
||||
|
||||
class SimpleOptimizer(tf.train.Optimizer):
|
||||
|
|
Loading…
Reference in a new issue