Change PeriodicRoundRestartIndicator
to return the first True
at a given number of calls. Also update the code style to be more compatible with graph mode and TFF.
PiperOrigin-RevId: 397918733
This commit is contained in:
parent
388f46ffa0
commit
c39d628e16
2 changed files with 38 additions and 5 deletions
|
@ -17,6 +17,7 @@ This query is used to compose with a DPQuery that has `reset_state` function.
|
||||||
"""
|
"""
|
||||||
import abc
|
import abc
|
||||||
import collections
|
import collections
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -60,17 +61,26 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
||||||
The indicator will maintain an internal counter as state.
|
The indicator will maintain an internal counter as state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, frequency: int):
|
def __init__(self, frequency: int, warmup: Optional[int] = None):
|
||||||
"""Construct the `PeriodicRoundRestartIndicator`.
|
"""Construct the `PeriodicRoundRestartIndicator`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frequency: The `next` function will return `True` every `frequency` number
|
frequency: The `next` function will return `True` every `frequency` number
|
||||||
of `next` calls.
|
of `next` calls.
|
||||||
|
warmup: The first `True` will be returned at the `warmup` times call of
|
||||||
|
`next`.
|
||||||
"""
|
"""
|
||||||
if frequency < 1:
|
if frequency < 1:
|
||||||
raise ValueError('Restart frequency should be equal or larger than 1 '
|
raise ValueError('Restart frequency should be equal or larger than 1, '
|
||||||
f'got {frequency}')
|
f'got {frequency}')
|
||||||
self.frequency = tf.constant(frequency, tf.int32)
|
if warmup is None:
|
||||||
|
warmup = 0
|
||||||
|
elif warmup <= 0 or warmup >= frequency:
|
||||||
|
raise ValueError(
|
||||||
|
f'Warmup should be between 1 and `frequency-1={frequency-1}`, '
|
||||||
|
f'got {warmup}')
|
||||||
|
self.frequency = frequency
|
||||||
|
self.warmup = warmup
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
|
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
|
||||||
|
@ -86,8 +96,10 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
||||||
A pair (value, new_state) where value is the bool indicator and new_state
|
A pair (value, new_state) where value is the bool indicator and new_state
|
||||||
of `state+1`.
|
of `state+1`.
|
||||||
"""
|
"""
|
||||||
|
frequency = tf.constant(self.frequency, tf.int32)
|
||||||
|
warmup = tf.constant(self.warmup, tf.int32)
|
||||||
state = state + tf.constant(1, tf.int32)
|
state = state + tf.constant(1, tf.int32)
|
||||||
flag = state % self.frequency == 0
|
flag = tf.math.equal(tf.math.floormod(state, frequency), warmup)
|
||||||
return flag, state
|
return flag, state
|
||||||
|
|
||||||
|
|
||||||
|
@ -132,6 +144,7 @@ class RestartQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||||
return self._inner_query.preprocess_record(params, record)
|
return self._inner_query.preprocess_record(params, record)
|
||||||
|
|
||||||
|
@tf.function
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
noised_results, inner_state, event = self._inner_query.get_noised_result(
|
noised_results, inner_state, event = self._inner_query.get_noised_result(
|
||||||
|
|
|
@ -27,6 +27,15 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
ValueError, 'Restart frequency should be equal or larger than 1'):
|
ValueError, 'Restart frequency should be equal or larger than 1'):
|
||||||
restart_query.PeriodicRoundRestartIndicator(frequency)
|
restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(('zero', 0), ('negative', -1), ('equal', 2),
|
||||||
|
('large', 3))
|
||||||
|
def test_round_raise_warmup(self, warmup):
|
||||||
|
frequency = 2
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
f'Warmup should be between 1 and `frequency-1={frequency-1}`'):
|
||||||
|
restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
|
||||||
|
|
||||||
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
||||||
def test_round_indicator(self, frequency):
|
def test_round_indicator(self, frequency):
|
||||||
total_steps = 20
|
total_steps = 20
|
||||||
|
@ -39,6 +48,18 @@ class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
else:
|
else:
|
||||||
self.assertFalse(flag)
|
self.assertFalse(flag)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(('f2', 2, 1), ('f4', 4, 3), ('f5', 5, 2))
|
||||||
|
def test_round_indicator_warmup(self, frequency, warmup):
|
||||||
|
total_steps = 20
|
||||||
|
indicator = restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
|
||||||
|
state = indicator.initialize()
|
||||||
|
for i in range(total_steps):
|
||||||
|
flag, state = indicator.next(state)
|
||||||
|
if i % frequency == warmup - 1:
|
||||||
|
self.assertTrue(flag)
|
||||||
|
else:
|
||||||
|
self.assertFalse(flag)
|
||||||
|
|
||||||
|
|
||||||
def _get_l2_clip_fn():
|
def _get_l2_clip_fn():
|
||||||
|
|
||||||
|
@ -118,7 +139,6 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
expected = scalar_value + tree_node_value * (
|
expected = scalar_value + tree_node_value * (
|
||||||
bin(i % frequency + 1)[2:].count('1') -
|
bin(i % frequency + 1)[2:].count('1') -
|
||||||
bin(i % frequency)[2:].count('1'))
|
bin(i % frequency)[2:].count('1'))
|
||||||
print(i, query_result, expected)
|
|
||||||
self.assertEqual(query_result, expected)
|
self.assertEqual(query_result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue