forked from 626_privacy/tensorflow_privacy
Time based indicator for restart query.
PiperOrigin-RevId: 401871582
This commit is contained in:
parent
7426a4ec30
commit
27bb6e48d9
2 changed files with 79 additions and 1 deletions
|
@ -103,6 +103,50 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
||||||
return flag, state
|
return flag, state
|
||||||
|
|
||||||
|
|
||||||
|
class PeriodicTimeRestartIndicator(RestartIndicator):
|
||||||
|
"""Indicator for periodically resetting the tree state after a certain time.
|
||||||
|
|
||||||
|
The indicator will maintain a state to track the previous restart time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, period_seconds: float):
|
||||||
|
"""Construct the `PeriodicTimeRestartIndicator`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
period_seconds: The `next` function will return `True` if called after
|
||||||
|
`period_seconds`.
|
||||||
|
"""
|
||||||
|
if period_seconds <= 0:
|
||||||
|
raise ValueError('Restart period_seconds should be larger than 0, got '
|
||||||
|
f'{period_seconds}')
|
||||||
|
self.period_seconds = period_seconds
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def initialize(self):
|
||||||
|
"""Returns initial time as state."""
|
||||||
|
return tf.timestamp()
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def next(self, state):
|
||||||
|
"""Gets next bool indicator and advances the state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A pair (value, new_state) where value is the bool indicator and new_state
|
||||||
|
of time.
|
||||||
|
"""
|
||||||
|
current_time = tf.timestamp()
|
||||||
|
current_period = current_time - state
|
||||||
|
reset_flag = tf.math.greater(
|
||||||
|
current_period,
|
||||||
|
tf.convert_to_tensor(self.period_seconds, current_period.dtype))
|
||||||
|
if reset_flag:
|
||||||
|
state = current_time
|
||||||
|
return reset_flag, state
|
||||||
|
|
||||||
|
|
||||||
class RestartQuery(dp_query.SumAggregationDPQuery):
|
class RestartQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function."""
|
"""`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function."""
|
||||||
|
|
||||||
|
|
|
@ -13,13 +13,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for `restart_query`."""
|
"""Tests for `restart_query`."""
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
import mock
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||||
|
|
||||||
|
|
||||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
class RoundRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
||||||
def test_round_raise(self, frequency):
|
def test_round_raise(self, frequency):
|
||||||
|
@ -61,6 +62,39 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertFalse(flag)
|
self.assertFalse(flag)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(('zero', 0), ('negative', -1.))
|
||||||
|
def test_round_raise(self, secs):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'Restart period_seconds should be larger than 0'):
|
||||||
|
restart_query.PeriodicTimeRestartIndicator(secs)
|
||||||
|
|
||||||
|
def test_round_indicator(self):
|
||||||
|
indicator = restart_query.PeriodicTimeRestartIndicator(period_seconds=3600 *
|
||||||
|
23.5)
|
||||||
|
# TODO(b/193679963): use `tf.timestamp` as the default of a member of
|
||||||
|
# the `PeriodicTimeRestartIndicator` to unroll the mock test.
|
||||||
|
return_time = tf.Variable(
|
||||||
|
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
|
||||||
|
with mock.patch.object(
|
||||||
|
tf, 'timestamp', return_value=return_time) as mock_func:
|
||||||
|
time_stamps = [
|
||||||
|
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False
|
||||||
|
1627105268.452365, # 23:41pm PST 5:41am UTC, July 23, 1 day, True
|
||||||
|
1627112468.452365, # 2 hr after restart, False
|
||||||
|
1627189508.452365, # 23.4 hr after restart, False
|
||||||
|
1627189904.452365, # 23.51 hr after restart, True
|
||||||
|
]
|
||||||
|
expected_values = [False, True, False, False, True]
|
||||||
|
state = indicator.initialize()
|
||||||
|
for v, t in zip(expected_values, time_stamps):
|
||||||
|
return_time.assign(t)
|
||||||
|
mock_func.return_value = return_time
|
||||||
|
flag, state = indicator.next(state)
|
||||||
|
self.assertEqual(v, flag.numpy())
|
||||||
|
|
||||||
|
|
||||||
def _get_l2_clip_fn():
|
def _get_l2_clip_fn():
|
||||||
|
|
||||||
def l2_clip_fn(record_as_list, clip_value):
|
def l2_clip_fn(record_as_list, clip_value):
|
||||||
|
|
Loading…
Reference in a new issue