Time based indicator for restart query.

PiperOrigin-RevId: 401871582
This commit is contained in:
Zheng Xu 2021-10-08 15:40:27 -07:00 committed by A. Unique TensorFlower
parent 7426a4ec30
commit 27bb6e48d9
2 changed files with 79 additions and 1 deletions

View file

@ -103,6 +103,50 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
return flag, state
class PeriodicTimeRestartIndicator(RestartIndicator):
"""Indicator for periodically resetting the tree state after a certain time.
The indicator will maintain a state to track the previous restart time.
"""
def __init__(self, period_seconds: float):
"""Construct the `PeriodicTimeRestartIndicator`.
Args:
period_seconds: The `next` function will return `True` if called after
`period_seconds`.
"""
if period_seconds <= 0:
raise ValueError('Restart period_seconds should be larger than 0, got '
f'{period_seconds}')
self.period_seconds = period_seconds
@tf.function
def initialize(self):
"""Returns initial time as state."""
return tf.timestamp()
@tf.function
def next(self, state):
"""Gets next bool indicator and advances the state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is the bool indicator and new_state
of time.
"""
current_time = tf.timestamp()
current_period = current_time - state
reset_flag = tf.math.greater(
current_period,
tf.convert_to_tensor(self.period_seconds, current_period.dtype))
if reset_flag:
state = current_time
return reset_flag, state
class RestartQuery(dp_query.SumAggregationDPQuery):
"""`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function."""

View file

@ -13,13 +13,14 @@
# limitations under the License.
"""Tests for `restart_query`."""
from absl.testing import parameterized
import mock
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
class RoundRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('zero', 0), ('negative', -1))
def test_round_raise(self, frequency):
@ -61,6 +62,39 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
self.assertFalse(flag)
class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('zero', 0), ('negative', -1.))
def test_round_raise(self, secs):
with self.assertRaisesRegex(
ValueError, 'Restart period_seconds should be larger than 0'):
restart_query.PeriodicTimeRestartIndicator(secs)
def test_round_indicator(self):
indicator = restart_query.PeriodicTimeRestartIndicator(period_seconds=3600 *
23.5)
# TODO(b/193679963): use `tf.timestamp` as the default of a member of
# the `PeriodicTimeRestartIndicator` to unroll the mock test.
return_time = tf.Variable(
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
with mock.patch.object(
tf, 'timestamp', return_value=return_time) as mock_func:
time_stamps = [
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False
1627105268.452365, # 23:41pm PST 5:41am UTC, July 23, 1 day, True
1627112468.452365, # 2 hr after restart, False
1627189508.452365, # 23.4 hr after restart, False
1627189904.452365, # 23.51 hr after restart, True
]
expected_values = [False, True, False, False, True]
state = indicator.initialize()
for v, t in zip(expected_values, time_stamps):
return_time.assign(t)
mock_func.return_value = return_time
flag, state = indicator.next(state)
self.assertEqual(v, flag.numpy())
def _get_l2_clip_fn():
def l2_clip_fn(record_as_list, clip_value):