Change PeriodicRoundRestartIndicator to return the first True at a given number of calls. Also update the code style to be more compatible with graph mode and TFF.

PiperOrigin-RevId: 397918733
This commit is contained in:
Zheng Xu 2021-09-20 22:38:23 -07:00 committed by A. Unique TensorFlower
parent 388f46ffa0
commit c39d628e16
2 changed files with 38 additions and 5 deletions

View file

@ -17,6 +17,7 @@ This query is used to compose with a DPQuery that has `reset_state` function.
""" """
import abc import abc
import collections import collections
from typing import Optional
import tensorflow as tf import tensorflow as tf
@ -60,17 +61,26 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
The indicator will maintain an internal counter as state. The indicator will maintain an internal counter as state.
""" """
def __init__(self, frequency: int): def __init__(self, frequency: int, warmup: Optional[int] = None):
"""Construct the `PeriodicRoundRestartIndicator`. """Construct the `PeriodicRoundRestartIndicator`.
Args: Args:
frequency: The `next` function will return `True` every `frequency` number frequency: The `next` function will return `True` every `frequency` number
of `next` calls. of `next` calls.
warmup: The first `True` will be returned at the `warmup` times call of
`next`.
""" """
if frequency < 1: if frequency < 1:
raise ValueError('Restart frequency should be equal or larger than 1 ' raise ValueError('Restart frequency should be equal or larger than 1, '
f'got {frequency}') f'got {frequency}')
self.frequency = tf.constant(frequency, tf.int32) if warmup is None:
warmup = 0
elif warmup <= 0 or warmup >= frequency:
raise ValueError(
f'Warmup should be between 1 and `frequency-1={frequency-1}`, '
f'got {warmup}')
self.frequency = frequency
self.warmup = warmup
def initialize(self): def initialize(self):
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`.""" """Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
@ -86,8 +96,10 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
A pair (value, new_state) where value is the bool indicator and new_state A pair (value, new_state) where value is the bool indicator and new_state
of `state+1`. of `state+1`.
""" """
frequency = tf.constant(self.frequency, tf.int32)
warmup = tf.constant(self.warmup, tf.int32)
state = state + tf.constant(1, tf.int32) state = state + tf.constant(1, tf.int32)
flag = state % self.frequency == 0 flag = tf.math.equal(tf.math.floormod(state, frequency), warmup)
return flag, state return flag, state
@ -132,6 +144,7 @@ class RestartQuery(dp_query.SumAggregationDPQuery):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" """Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
return self._inner_query.preprocess_record(params, record) return self._inner_query.preprocess_record(params, record)
@tf.function
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_results, inner_state, event = self._inner_query.get_noised_result( noised_results, inner_state, event = self._inner_query.get_noised_result(

View file

@ -27,6 +27,15 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
ValueError, 'Restart frequency should be equal or larger than 1'): ValueError, 'Restart frequency should be equal or larger than 1'):
restart_query.PeriodicRoundRestartIndicator(frequency) restart_query.PeriodicRoundRestartIndicator(frequency)
@parameterized.named_parameters(('zero', 0), ('negative', -1), ('equal', 2),
('large', 3))
def test_round_raise_warmup(self, warmup):
frequency = 2
with self.assertRaisesRegex(
ValueError,
f'Warmup should be between 1 and `frequency-1={frequency-1}`'):
restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5)) @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
def test_round_indicator(self, frequency): def test_round_indicator(self, frequency):
total_steps = 20 total_steps = 20
@ -39,6 +48,18 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
else: else:
self.assertFalse(flag) self.assertFalse(flag)
@parameterized.named_parameters(('f2', 2, 1), ('f4', 4, 3), ('f5', 5, 2))
def test_round_indicator_warmup(self, frequency, warmup):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
state = indicator.initialize()
for i in range(total_steps):
flag, state = indicator.next(state)
if i % frequency == warmup - 1:
self.assertTrue(flag)
else:
self.assertFalse(flag)
def _get_l2_clip_fn(): def _get_l2_clip_fn():
@ -118,7 +139,6 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
expected = scalar_value + tree_node_value * ( expected = scalar_value + tree_node_value * (
bin(i % frequency + 1)[2:].count('1') - bin(i % frequency + 1)[2:].count('1') -
bin(i % frequency)[2:].count('1')) bin(i % frequency)[2:].count('1'))
print(i, query_result, expected)
self.assertEqual(query_result, expected) self.assertEqual(query_result, expected)