Change PeriodicRoundRestartIndicator
to return the first True
at a given number of calls. Also update the code style to be more compatible with graph mode and TFF.
PiperOrigin-RevId: 397918733
This commit is contained in:
parent
388f46ffa0
commit
c39d628e16
2 changed files with 38 additions and 5 deletions
|
@ -17,6 +17,7 @@ This query is used to compose with a DPQuery that has `reset_state` function.
|
|||
"""
|
||||
import abc
|
||||
import collections
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -60,17 +61,26 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
|||
The indicator will maintain an internal counter as state.
|
||||
"""
|
||||
|
||||
def __init__(self, frequency: int):
|
||||
def __init__(self, frequency: int, warmup: Optional[int] = None):
|
||||
"""Construct the `PeriodicRoundRestartIndicator`.
|
||||
|
||||
Args:
|
||||
frequency: The `next` function will return `True` every `frequency` number
|
||||
of `next` calls.
|
||||
warmup: The first `True` will be returned at the `warmup` times call of
|
||||
`next`.
|
||||
"""
|
||||
if frequency < 1:
|
||||
raise ValueError('Restart frequency should be equal or larger than 1 '
|
||||
raise ValueError('Restart frequency should be equal or larger than 1, '
|
||||
f'got {frequency}')
|
||||
self.frequency = tf.constant(frequency, tf.int32)
|
||||
if warmup is None:
|
||||
warmup = 0
|
||||
elif warmup <= 0 or warmup >= frequency:
|
||||
raise ValueError(
|
||||
f'Warmup should be between 1 and `frequency-1={frequency-1}`, '
|
||||
f'got {warmup}')
|
||||
self.frequency = frequency
|
||||
self.warmup = warmup
|
||||
|
||||
def initialize(self):
|
||||
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
|
||||
|
@ -86,8 +96,10 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
|||
A pair (value, new_state) where value is the bool indicator and new_state
|
||||
of `state+1`.
|
||||
"""
|
||||
frequency = tf.constant(self.frequency, tf.int32)
|
||||
warmup = tf.constant(self.warmup, tf.int32)
|
||||
state = state + tf.constant(1, tf.int32)
|
||||
flag = state % self.frequency == 0
|
||||
flag = tf.math.equal(tf.math.floormod(state, frequency), warmup)
|
||||
return flag, state
|
||||
|
||||
|
||||
|
@ -132,6 +144,7 @@ class RestartQuery(dp_query.SumAggregationDPQuery):
|
|||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
return self._inner_query.preprocess_record(params, record)
|
||||
|
||||
@tf.function
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
noised_results, inner_state, event = self._inner_query.get_noised_result(
|
||||
|
|
|
@ -27,6 +27,15 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
ValueError, 'Restart frequency should be equal or larger than 1'):
|
||||
restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1), ('equal', 2),
|
||||
('large', 3))
|
||||
def test_round_raise_warmup(self, warmup):
|
||||
frequency = 2
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f'Warmup should be between 1 and `frequency-1={frequency-1}`'):
|
||||
restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
|
||||
|
||||
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
||||
def test_round_indicator(self, frequency):
|
||||
total_steps = 20
|
||||
|
@ -39,6 +48,18 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
@parameterized.named_parameters(('f2', 2, 1), ('f4', 4, 3), ('f5', 5, 2))
|
||||
def test_round_indicator_warmup(self, frequency, warmup):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
|
||||
state = indicator.initialize()
|
||||
for i in range(total_steps):
|
||||
flag, state = indicator.next(state)
|
||||
if i % frequency == warmup - 1:
|
||||
self.assertTrue(flag)
|
||||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
|
||||
def _get_l2_clip_fn():
|
||||
|
||||
|
@ -118,7 +139,6 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
expected = scalar_value + tree_node_value * (
|
||||
bin(i % frequency + 1)[2:].count('1') -
|
||||
bin(i % frequency)[2:].count('1'))
|
||||
print(i, query_result, expected)
|
||||
self.assertEqual(query_result, expected)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue