Remove the mock
dependency from TensorFlow Privacy, this is now part of the Python standard library.
PiperOrigin-RevId: 424681527
This commit is contained in:
parent
a749ce4e30
commit
81a11eb824
3 changed files with 10 additions and 8 deletions
|
@ -11,11 +11,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for `restart_query`."""
|
|
||||||
from absl.testing import parameterized
|
|
||||||
import mock
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||||
|
|
||||||
|
@ -77,7 +78,7 @@ class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
# the `PeriodicTimeRestartIndicator` to unroll the mock test.
|
# the `PeriodicTimeRestartIndicator` to unroll the mock test.
|
||||||
return_time = tf.Variable(
|
return_time = tf.Variable(
|
||||||
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
|
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
|
||||||
with mock.patch.object(
|
with unittest.mock.patch.object(
|
||||||
tf, 'timestamp', return_value=return_time) as mock_func:
|
tf, 'timestamp', return_value=return_time) as mock_func:
|
||||||
time_stamps = [
|
time_stamps = [
|
||||||
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False
|
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False
|
||||||
|
|
|
@ -13,9 +13,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import mock
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
@ -175,7 +175,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||||
self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5)
|
self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5)
|
||||||
|
|
||||||
@mock.patch('absl.logging.warning')
|
@unittest.mock.patch('absl.logging.warning')
|
||||||
def testComputeGradientsOverrideWarning(self, mock_logging):
|
def testComputeGradientsOverrideWarning(self, mock_logging):
|
||||||
|
|
||||||
class SimpleOptimizer(tf.train.Optimizer):
|
class SimpleOptimizer(tf.train.Optimizer):
|
||||||
|
|
|
@ -12,8 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import mock
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
@ -112,7 +113,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||||
self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)
|
self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)
|
||||||
|
|
||||||
@mock.patch('absl.logging.warning')
|
@unittest.mock.patch('absl.logging.warning')
|
||||||
def testComputeGradientsOverrideWarning(self, mock_logging):
|
def testComputeGradientsOverrideWarning(self, mock_logging):
|
||||||
|
|
||||||
class SimpleOptimizer(tf.train.Optimizer):
|
class SimpleOptimizer(tf.train.Optimizer):
|
||||||
|
|
Loading…
Reference in a new issue