diff --git a/tensorflow_privacy/privacy/dp_query/restart_query.py b/tensorflow_privacy/privacy/dp_query/restart_query.py index 5716b0b..8d1ff8c 100644 --- a/tensorflow_privacy/privacy/dp_query/restart_query.py +++ b/tensorflow_privacy/privacy/dp_query/restart_query.py @@ -17,6 +17,7 @@ This query is used to compose with a DPQuery that has `reset_state` function. """ import abc import collections +from typing import Optional import tensorflow as tf @@ -60,17 +61,26 @@ class PeriodicRoundRestartIndicator(RestartIndicator): The indicator will maintain an internal counter as state. """ - def __init__(self, frequency: int): + def __init__(self, frequency: int, warmup: Optional[int] = None): """Construct the `PeriodicRoundRestartIndicator`. Args: frequency: The `next` function will return `True` every `frequency` number of `next` calls. + warmup: The first `True` will be returned at the `warmup` times call of + `next`. """ if frequency < 1: - raise ValueError('Restart frequency should be equal or larger than 1 ' + raise ValueError('Restart frequency should be equal or larger than 1, ' f'got {frequency}') - self.frequency = tf.constant(frequency, tf.int32) + if warmup is None: + warmup = 0 + elif warmup <= 0 or warmup >= frequency: + raise ValueError( + f'Warmup should be between 1 and `frequency-1={frequency-1}`, ' + f'got {warmup}') + self.frequency = frequency + self.warmup = warmup def initialize(self): """Returns initialized state of 0 for `PeriodicRoundRestartIndicator`.""" @@ -86,8 +96,10 @@ class PeriodicRoundRestartIndicator(RestartIndicator): A pair (value, new_state) where value is the bool indicator and new_state of `state+1`. """ + frequency = tf.constant(self.frequency, tf.int32) + warmup = tf.constant(self.warmup, tf.int32) state = state + tf.constant(1, tf.int32) - flag = state % self.frequency == 0 + flag = tf.math.equal(tf.math.floormod(state, frequency), warmup) return flag, state @@ -132,6 +144,7 @@ class RestartQuery(dp_query.SumAggregationDPQuery): """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" return self._inner_query.preprocess_record(params, record) + @tf.function def get_noised_result(self, sample_state, global_state): """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" noised_results, inner_state, event = self._inner_query.get_noised_result( diff --git a/tensorflow_privacy/privacy/dp_query/restart_query_test.py b/tensorflow_privacy/privacy/dp_query/restart_query_test.py index f3a0276..bf6c374 100644 --- a/tensorflow_privacy/privacy/dp_query/restart_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/restart_query_test.py @@ -27,6 +27,15 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): ValueError, 'Restart frequency should be equal or larger than 1'): restart_query.PeriodicRoundRestartIndicator(frequency) + @parameterized.named_parameters(('zero', 0), ('negative', -1), ('equal', 2), + ('large', 3)) + def test_round_raise_warmup(self, warmup): + frequency = 2 + with self.assertRaisesRegex( + ValueError, + f'Warmup should be between 1 and `frequency-1={frequency-1}`'): + restart_query.PeriodicRoundRestartIndicator(frequency, warmup) + @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5)) def test_round_indicator(self, frequency): total_steps = 20 @@ -39,6 +48,18 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): else: self.assertFalse(flag) + @parameterized.named_parameters(('f2', 2, 1), ('f4', 4, 3), ('f5', 5, 2)) + def test_round_indicator_warmup(self, frequency, warmup): + total_steps = 20 + indicator = restart_query.PeriodicRoundRestartIndicator(frequency, warmup) + state = indicator.initialize() + for i in range(total_steps): + flag, state = indicator.next(state) + if i % frequency == warmup - 1: + self.assertTrue(flag) + else: + self.assertFalse(flag) + def _get_l2_clip_fn(): @@ -118,7 +139,6 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase): expected = scalar_value + tree_node_value * ( bin(i % frequency + 1)[2:].count('1') - bin(i % frequency)[2:].count('1')) - print(i, query_result, expected) self.assertEqual(query_result, expected)