From 27bb6e48d9b218c642def488befa38ddd83511a0 Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Fri, 8 Oct 2021 15:40:27 -0700 Subject: [PATCH] Time based indicator for restart query. PiperOrigin-RevId: 401871582 --- .../privacy/dp_query/restart_query.py | 44 +++++++++++++++++++ .../privacy/dp_query/restart_query_test.py | 36 ++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/tensorflow_privacy/privacy/dp_query/restart_query.py b/tensorflow_privacy/privacy/dp_query/restart_query.py index 8d1ff8c..38c08f9 100644 --- a/tensorflow_privacy/privacy/dp_query/restart_query.py +++ b/tensorflow_privacy/privacy/dp_query/restart_query.py @@ -103,6 +103,50 @@ class PeriodicRoundRestartIndicator(RestartIndicator): return flag, state +class PeriodicTimeRestartIndicator(RestartIndicator): + """Indicator for periodically resetting the tree state after a certain time. + + The indicator will maintain a state to track the previous restart time. + """ + + def __init__(self, period_seconds: float): + """Construct the `PeriodicTimeRestartIndicator`. + + Args: + period_seconds: The `next` function will return `True` if called after + `period_seconds`. + """ + if period_seconds <= 0: + raise ValueError('Restart period_seconds should be larger than 0, got ' + f'{period_seconds}') + self.period_seconds = period_seconds + + @tf.function + def initialize(self): + """Returns initial time as state.""" + return tf.timestamp() + + @tf.function + def next(self, state): + """Gets next bool indicator and advances the state. + + Args: + state: The current state. + + Returns: + A pair (value, new_state) where value is the bool indicator and new_state + of time. + """ + current_time = tf.timestamp() + current_period = current_time - state + reset_flag = tf.math.greater( + current_period, + tf.convert_to_tensor(self.period_seconds, current_period.dtype)) + if reset_flag: + state = current_time + return reset_flag, state + + class RestartQuery(dp_query.SumAggregationDPQuery): """`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function.""" diff --git a/tensorflow_privacy/privacy/dp_query/restart_query_test.py b/tensorflow_privacy/privacy/dp_query/restart_query_test.py index bf6c374..ce05ed2 100644 --- a/tensorflow_privacy/privacy/dp_query/restart_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/restart_query_test.py @@ -13,13 +13,14 @@ # limitations under the License. """Tests for `restart_query`.""" from absl.testing import parameterized +import mock import tensorflow as tf from tensorflow_privacy.privacy.dp_query import restart_query from tensorflow_privacy.privacy.dp_query import tree_aggregation_query -class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): +class RoundRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters(('zero', 0), ('negative', -1)) def test_round_raise(self, frequency): @@ -61,6 +62,39 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): self.assertFalse(flag) +class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters(('zero', 0), ('negative', -1.)) + def test_round_raise(self, secs): + with self.assertRaisesRegex( + ValueError, 'Restart period_seconds should be larger than 0'): + restart_query.PeriodicTimeRestartIndicator(secs) + + def test_round_indicator(self): + indicator = restart_query.PeriodicTimeRestartIndicator(period_seconds=3600 * + 23.5) + # TODO(b/193679963): use `tf.timestamp` as the default of a member of + # the `PeriodicTimeRestartIndicator` to unroll the mock test. + return_time = tf.Variable( + 1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize + with mock.patch.object( + tf, 'timestamp', return_value=return_time) as mock_func: + time_stamps = [ + 1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False + 1627105268.452365, # 23:41pm PST 5:41am UTC, July 23, 1 day, True + 1627112468.452365, # 2 hr after restart, False + 1627189508.452365, # 23.4 hr after restart, False + 1627189904.452365, # 23.51 hr after restart, True + ] + expected_values = [False, True, False, False, True] + state = indicator.initialize() + for v, t in zip(expected_values, time_stamps): + return_time.assign(t) + mock_func.return_value = return_time + flag, state = indicator.next(state) + self.assertEqual(v, flag.numpy()) + + def _get_l2_clip_fn(): def l2_clip_fn(record_as_list, clip_value):