forked from 626_privacy/tensorflow_privacy
Time based indicator for restart query.
PiperOrigin-RevId: 401871582
This commit is contained in:
parent
7426a4ec30
commit
27bb6e48d9
2 changed files with 79 additions and 1 deletions
|
@ -103,6 +103,50 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
|||
return flag, state
|
||||
|
||||
|
||||
class PeriodicTimeRestartIndicator(RestartIndicator):
|
||||
"""Indicator for periodically resetting the tree state after a certain time.
|
||||
|
||||
The indicator will maintain a state to track the previous restart time.
|
||||
"""
|
||||
|
||||
def __init__(self, period_seconds: float):
|
||||
"""Construct the `PeriodicTimeRestartIndicator`.
|
||||
|
||||
Args:
|
||||
period_seconds: The `next` function will return `True` if called after
|
||||
`period_seconds`.
|
||||
"""
|
||||
if period_seconds <= 0:
|
||||
raise ValueError('Restart period_seconds should be larger than 0, got '
|
||||
f'{period_seconds}')
|
||||
self.period_seconds = period_seconds
|
||||
|
||||
@tf.function
|
||||
def initialize(self):
|
||||
"""Returns initial time as state."""
|
||||
return tf.timestamp()
|
||||
|
||||
@tf.function
|
||||
def next(self, state):
|
||||
"""Gets next bool indicator and advances the state.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is the bool indicator and new_state
|
||||
of time.
|
||||
"""
|
||||
current_time = tf.timestamp()
|
||||
current_period = current_time - state
|
||||
reset_flag = tf.math.greater(
|
||||
current_period,
|
||||
tf.convert_to_tensor(self.period_seconds, current_period.dtype))
|
||||
if reset_flag:
|
||||
state = current_time
|
||||
return reset_flag, state
|
||||
|
||||
|
||||
class RestartQuery(dp_query.SumAggregationDPQuery):
|
||||
"""`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function."""
|
||||
|
||||
|
|
|
@ -13,13 +13,14 @@
|
|||
# limitations under the License.
|
||||
"""Tests for `restart_query`."""
|
||||
from absl.testing import parameterized
|
||||
import mock
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||
|
||||
|
||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
class RoundRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
||||
def test_round_raise(self, frequency):
|
||||
|
@ -61,6 +62,39 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertFalse(flag)
|
||||
|
||||
|
||||
class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1.))
|
||||
def test_round_raise(self, secs):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Restart period_seconds should be larger than 0'):
|
||||
restart_query.PeriodicTimeRestartIndicator(secs)
|
||||
|
||||
def test_round_indicator(self):
|
||||
indicator = restart_query.PeriodicTimeRestartIndicator(period_seconds=3600 *
|
||||
23.5)
|
||||
# TODO(b/193679963): use `tf.timestamp` as the default of a member of
|
||||
# the `PeriodicTimeRestartIndicator` to unroll the mock test.
|
||||
return_time = tf.Variable(
|
||||
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
|
||||
with mock.patch.object(
|
||||
tf, 'timestamp', return_value=return_time) as mock_func:
|
||||
time_stamps = [
|
||||
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False
|
||||
1627105268.452365, # 23:41pm PST 5:41am UTC, July 23, 1 day, True
|
||||
1627112468.452365, # 2 hr after restart, False
|
||||
1627189508.452365, # 23.4 hr after restart, False
|
||||
1627189904.452365, # 23.51 hr after restart, True
|
||||
]
|
||||
expected_values = [False, True, False, False, True]
|
||||
state = indicator.initialize()
|
||||
for v, t in zip(expected_values, time_stamps):
|
||||
return_time.assign(t)
|
||||
mock_func.return_value = return_time
|
||||
flag, state = indicator.next(state)
|
||||
self.assertEqual(v, flag.numpy())
|
||||
|
||||
|
||||
def _get_l2_clip_fn():
|
||||
|
||||
def l2_clip_fn(record_as_list, clip_value):
|
||||
|
|
Loading…
Reference in a new issue