Define RestartQuery for easy composition to restart tree in tree aggregation queries.

PiperOrigin-RevId: 394106175
This commit is contained in:
Zheng Xu 2021-08-31 16:05:57 -07:00 committed by A. Unique TensorFlower
parent 789a05df63
commit 6ac4bc8d01
7 changed files with 336 additions and 163 deletions

View file

@ -62,7 +62,9 @@ else:
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery

View file

@ -0,0 +1,148 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for restarting the states of another query.
This query is used to compose with a DPQuery that has `reset_state` function.
"""
import abc
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
class RestartIndicator(metaclass=abc.ABCMeta):
"""Base class establishing interface for restarting the tree state.
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
value is generated to indicate whether to restart, and the indicator state is
advanced.
"""
@abc.abstractmethod
def initialize(self):
"""Makes an initialized state for `RestartIndicator`.
Returns:
An initial state.
"""
raise NotImplementedError
@abc.abstractmethod
def next(self, state):
"""Gets next bool indicator and advances the `RestartIndicator` state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is bool indicator and new_state
is the advanced state.
"""
raise NotImplementedError
class PeriodicRoundRestartIndicator(RestartIndicator):
"""Indicator for resetting the tree state after every a few number of queries.
The indicator will maintain an internal counter as state.
"""
def __init__(self, frequency: int):
"""Construct the `PeriodicRoundRestartIndicator`.
Args:
frequency: The `next` function will return `True` every `frequency` number
of `next` calls.
"""
if frequency < 1:
raise ValueError('Restart frequency should be equal or larger than 1 '
f'got {frequency}')
self.frequency = tf.constant(frequency, tf.int32)
def initialize(self):
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
return tf.constant(0, tf.int32)
def next(self, state):
"""Gets next bool indicator and advances the state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is the bool indicator and new_state
of `state+1`.
"""
state = state + tf.constant(1, tf.int32)
flag = state % self.frequency == 0
return flag, state
class RestartQuery(dp_query.SumAggregationDPQuery):
"""`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function."""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple(
'_GlobalState', ['inner_query_state', 'indicator_state'])
def __init__(self, inner_query: dp_query.SumAggregationDPQuery,
restart_indicator: RestartIndicator):
"""Initializes `RestartQuery`.
Args:
inner_query: A `SumAggregationDPQuery` has `reset_state` attribute.
restart_indicator: A `RestartIndicator` to generate the boolean indicator
for resetting the state.
"""
if not hasattr(inner_query, 'reset_state'):
raise ValueError(f'{type(inner_query)} must define `reset_state` to be '
'composed with `RestartQuery`.')
self._inner_query = inner_query
self._restart_indicator = restart_indicator
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._GlobalState(
inner_query_state=self._inner_query.initial_global_state(),
indicator_state=self._restart_indicator.initialize())
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return self._inner_query.derive_sample_params(
global_state.inner_query_state)
def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return self._inner_query.initial_sample_state(template)
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
return self._inner_query.preprocess_record(params, record)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_results, inner_query_state = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
restart_flag, indicator_state = self._restart_indicator.next(
global_state.indicator_state)
if restart_flag:
inner_query_state = self._inner_query.reset_state(noised_results,
inner_query_state)
return noised_results, self._GlobalState(inner_query_state, indicator_state)
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
return self._inner_query.derive_metrics(global_state.inner_query_state)

View file

@ -0,0 +1,126 @@
# Copyright 2021, Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `restart_query`."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('zero', 0), ('negative', -1))
def test_round_raise(self, frequency):
with self.assertRaisesRegex(
ValueError, 'Restart frequency should be equal or larger than 1'):
restart_query.PeriodicRoundRestartIndicator(frequency)
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
def test_round_indicator(self, frequency):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
state = indicator.initialize()
for i in range(total_steps):
flag, state = indicator.next(state)
if i % frequency == frequency - 1:
self.assertTrue(flag)
else:
self.assertFalse(flag)
def _get_l2_clip_fn():
def l2_clip_fn(record_as_list, clip_value):
clipped_record, _ = tf.clip_by_global_norm(record_as_list, clip_value)
return clipped_record
return l2_clip_fn
class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
tree_node_value, frequency):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False)
query = restart_query.RestartQuery(query, indicator)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
# Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = (
scalar_value * (i + 1) +
i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
tree_node_value * bin(i % frequency + 1)[2:].count('1'))
self.assertEqual(query_result, expected)
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
frequency):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False)
query = restart_query.RestartQuery(query, indicator)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
# Expected value is the signal of the current round plus the residual of
# two continous tree aggregation values. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = scalar_value + tree_node_value * (
bin(i % frequency + 1)[2:].count('1') -
bin(i % frequency)[2:].count('1'))
print(i, query_result, expected)
self.assertEqual(query_result, expected)
if __name__ == '__main__':
tf.test.main()

View file

@ -171,78 +171,6 @@ class StatelessValueGenerator(ValueGenerator):
return self.value_fn(), state
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
class RestartIndicator(metaclass=abc.ABCMeta):
"""Base class establishing interface for restarting the tree state.
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
value is generated to indicate whether to restart, and the indicator state is
advanced.
"""
@abc.abstractmethod
def initialize(self):
"""Makes an initialized state for `RestartIndicator`.
Returns:
An initial state.
"""
raise NotImplementedError
@abc.abstractmethod
def next(self, state):
"""Gets next bool indicator and advances the `RestartIndicator` state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is bool indicator and new_state
is the advanced state.
"""
raise NotImplementedError
class PeriodicRoundRestartIndicator(RestartIndicator):
"""Indicator for resetting the tree state after every a few number of queries.
The indicator will maintain an internal counter as state.
"""
def __init__(self, frequency: int):
"""Construct the `PeriodicRoundRestartIndicator`.
Args:
frequency: The `next` function will return `True` every `frequency` number
of `next` calls.
"""
if frequency < 1:
raise ValueError('Restart frequency should be equal or larger than 1 '
f'got {frequency}')
self.frequency = tf.constant(frequency, tf.int32)
def initialize(self):
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
return tf.constant(0, tf.int32)
def next(self, state):
"""Gets next bool indicator and advances the state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is the bool indicator and new_state
of `state+1`.
"""
state = state + tf.constant(1, tf.int32)
flag = state % self.frequency == 0
return flag, state
@attr.s(eq=False, frozen=True, slots=True)
class TreeState(object):
"""Class defining state of the tree.

View file

@ -72,8 +72,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
node. Noise stdandard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
@attr.s(frozen=True)
@ -85,21 +83,17 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
each level state.
clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
"""
tree_state = attr.ib()
clip_value = attr.ib()
samples_cumulative_sum = attr.ib()
restarter_state = attr.ib()
def __init__(self,
record_specs,
noise_generator,
clip_fn,
clip_value,
use_efficient=True,
restart_indicator=None):
use_efficient=True):
"""Initializes the `TreeCumulativeSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a
@ -117,8 +111,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
self._clip_fn = clip_fn
self._clip_value = clip_value
@ -128,21 +120,16 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
noise_generator)
else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state()
initial_samples_cumulative_sum = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape), self._record_specs)
restarter_state = ()
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeCumulativeSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
samples_cumulative_sum=initial_samples_cumulative_sum,
restarter_state=restarter_state)
samples_cumulative_sum=initial_samples_cumulative_sum)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -185,28 +172,41 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
new_cumulative_sum = noised_cumulative_sum
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve(
global_state,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state,
restarter_state=restarter_state)
tree_state=new_tree_state)
return noised_cumulative_sum, new_global_state
def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree.
This function will be used in `restart_query.RestartQuery` after calling
`get_noised_result` when the restarting condition is met.
Args:
noised_results: Noised cumulative sum returned by `get_noised_result`.
global_state: Updated global state returned by `get_noised_result`, which
has current sample's cumulative sum and tree state for the next
cumulative sum.
Returns:
New global state with current noised cumulative sum and restarted tree
state for the next cumulative sum.
"""
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
samples_cumulative_sum=noised_results,
tree_state=new_tree_state)
@classmethod
def build_l2_gaussian_query(cls,
clip_norm,
noise_multiplier,
record_specs,
noise_seed=None,
use_efficient=True,
restart_indicator=None):
use_efficient=True):
"""Returns a query instance with L2 norm clipping and Gaussian noise.
Args:
@ -221,8 +221,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -245,8 +243,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm,
record_specs=record_specs,
noise_generator=gaussian_noise_generator,
use_efficient=use_efficient,
restart_indicator=restart_indicator)
use_efficient=use_efficient)
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
@ -300,8 +297,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
node. Noise stdandard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
@attr.s(frozen=True)
@ -314,21 +309,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value: The clipping value to be passed to clip_fn.
previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
"""
tree_state = attr.ib()
clip_value = attr.ib()
previous_tree_noise = attr.ib()
restarter_state = attr.ib()
def __init__(self,
record_specs,
noise_generator,
clip_fn,
clip_value,
use_efficient=True,
restart_indicator=None):
use_efficient=True):
"""Initializes the `TreeCumulativeSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a
@ -346,8 +337,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
self._clip_fn = clip_fn
self._clip_value = clip_value
@ -357,7 +346,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noise_generator)
else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def _zero_initial_noise(self):
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
@ -366,14 +354,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state()
restarter_state = ()
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeResidualSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
previous_tree_noise=self._zero_initial_noise(),
restarter_state=restarter_state)
previous_tree_noise=self._zero_initial_noise())
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -412,28 +396,39 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise,
global_state.previous_tree_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
tree_noise = self._zero_initial_noise()
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve(
global_state,
previous_tree_noise=tree_noise,
tree_state=new_tree_state,
restarter_state=restarter_state)
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
return noised_sample, new_global_state
def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree.
This function will be used in `restart_query.RestartQuery` after calling
`get_noised_result` when the restarting condition is met.
Args:
noised_results: Noised cumulative sum returned by `get_noised_result`.
global_state: Updated global state returned by `get_noised_result`, which
records noise for the conceptual cumulative sum of the current leaf
node, and tree state for the next conceptual cumulative sum.
Returns:
New global state with zero noise and restarted tree state.
"""
del noised_results
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
previous_tree_noise=self._zero_initial_noise(),
tree_state=new_tree_state)
@classmethod
def build_l2_gaussian_query(cls,
clip_norm,
noise_multiplier,
record_specs,
noise_seed=None,
use_efficient=True,
restart_indicator=None):
use_efficient=True):
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
Args:
@ -448,8 +443,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -472,8 +465,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm,
record_specs=record_specs,
noise_generator=gaussian_noise_generator,
use_efficient=use_efficient,
restart_indicator=restart_indicator)
use_efficient=use_efficient)
# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next

View file

@ -303,15 +303,12 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
tree_node_value, frequency):
total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False,
restart_indicator=indicator,
)
use_efficient=False)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
@ -319,6 +316,8 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
# Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can
# be inferred from the binary representation of the current step.
@ -446,15 +445,12 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
frequency):
total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False,
restart_indicator=indicator,
)
use_efficient=False)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
@ -462,6 +458,8 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
# Expected value is the signal of the current round plus the residual of
# two continous tree aggregation values. The tree aggregation value can
# be inferred from the binary representation of the current step.

View file

@ -396,26 +396,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('zero', 0), ('negative', -1))
def test_round_raise(self, frequency):
with self.assertRaisesRegex(
ValueError, 'Restart frequency should be equal or larger than 1'):
tree_aggregation.PeriodicRoundRestartIndicator(frequency)
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
def test_round_indicator(self, frequency):
total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
state = indicator.initialize()
for i in range(total_steps):
flag, state = indicator.next(state)
if i % frequency == frequency - 1:
self.assertTrue(flag)
else:
self.assertFalse(flag)
if __name__ == '__main__':
tf.test.main()