diff --git a/tensorflow_privacy/__init__.py b/tensorflow_privacy/__init__.py
index a2e25d1..29b64ff 100644
--- a/tensorflow_privacy/__init__.py
+++ b/tensorflow_privacy/__init__.py
@@ -62,7 +62,9 @@ else:
   from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
   from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
   from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
+  from tensorflow_privacy.privacy.dp_query import restart_query
   from tensorflow_privacy.privacy.dp_query import tree_aggregation
+  from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
   from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
   from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
   from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery
diff --git a/tensorflow_privacy/privacy/dp_query/restart_query.py b/tensorflow_privacy/privacy/dp_query/restart_query.py
new file mode 100644
index 0000000..b3994cc
--- /dev/null
+++ b/tensorflow_privacy/privacy/dp_query/restart_query.py
@@ -0,0 +1,162 @@
+# Copyright 2021, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implements DPQuery interface for restarting the states of another query.
+
+This query is used to compose with a DPQuery that has `reset_state` function.
+"""
+import abc
+import collections
+
+import tensorflow as tf
+
+from tensorflow_privacy.privacy.dp_query import dp_query
+
+
+class RestartIndicator(metaclass=abc.ABCMeta):
+  """Base class establishing interface for restarting the tree state.
+
+  A `RestartIndicator` maintains a state, and each time `next` is called, a bool
+  value is generated to indicate whether to restart, and the indicator state is
+  advanced.
+  """
+
+  @abc.abstractmethod
+  def initialize(self):
+    """Makes an initialized state for `RestartIndicator`.
+
+    Returns:
+      An initial state.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def next(self, state):
+    """Gets next bool indicator and advances the `RestartIndicator` state.
+
+    Args:
+      state: The current state.
+
+    Returns:
+      A pair (value, new_state) where value is bool indicator and new_state
+        is the advanced state.
+    """
+    raise NotImplementedError
+
+
+class PeriodicRoundRestartIndicator(RestartIndicator):
+  """Indicator for resetting the tree state after a fixed number of queries.
+
+  The indicator will maintain an internal counter as state.
+  """
+
+  def __init__(self, frequency: int):
+    """Constructs the `PeriodicRoundRestartIndicator`.
+
+    Args:
+      frequency: The `next` function will return `True` every `frequency` number
+        of `next` calls.
+
+    Raises:
+      ValueError: If `frequency` is smaller than 1.
+    """
+    if frequency < 1:
+      raise ValueError('Restart frequency should be equal or larger than 1, '
+                       f'got {frequency}.')
+    self.frequency = tf.constant(frequency, tf.int32)
+
+  def initialize(self):
+    """Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
+    return tf.constant(0, tf.int32)
+
+  def next(self, state):
+    """Gets next bool indicator and advances the state.
+
+    Args:
+      state: The current state.
+
+    Returns:
+      A pair (value, new_state) where value is the bool indicator and new_state
+        is `state + 1`.
+    """
+    state = state + tf.constant(1, tf.int32)
+    flag = state % self.frequency == 0
+    return flag, state
+
+
+class RestartQuery(dp_query.SumAggregationDPQuery):
+  """`DPQuery` that restarts the state of an inner `SumAggregationDPQuery`.
+
+  Record processing is delegated to the inner query. Each call to
+  `get_noised_result` also advances the `RestartIndicator`; when it returns
+  `True`, the inner query state is replaced by the result of `reset_state`.
+  """
+
+  # pylint: disable=invalid-name
+  _GlobalState = collections.namedtuple(
+      '_GlobalState', ['inner_query_state', 'indicator_state'])
+
+  def __init__(self, inner_query: dp_query.SumAggregationDPQuery,
+               restart_indicator: RestartIndicator):
+    """Initializes `RestartQuery`.
+
+    Args:
+      inner_query: A `SumAggregationDPQuery` that defines a callable
+        `reset_state(noised_results, global_state)` method.
+      restart_indicator: A `RestartIndicator` to generate the boolean indicator
+        for resetting the state.
+
+    Raises:
+      ValueError: If `inner_query` does not define a callable `reset_state`.
+    """
+    if not callable(getattr(inner_query, 'reset_state', None)):
+      raise ValueError(f'{type(inner_query)} must define `reset_state` to be '
+                       'composed with `RestartQuery`.')
+    self._inner_query = inner_query
+    self._restart_indicator = restart_indicator
+
+  def initial_global_state(self):
+    """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
+    return self._GlobalState(
+        inner_query_state=self._inner_query.initial_global_state(),
+        indicator_state=self._restart_indicator.initialize())
+
+  def derive_sample_params(self, global_state):
+    """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
+    return self._inner_query.derive_sample_params(
+        global_state.inner_query_state)
+
+  def initial_sample_state(self, template):
+    """Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
+    return self._inner_query.initial_sample_state(template)
+
+  def preprocess_record(self, params, record):
+    """Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
+    return self._inner_query.preprocess_record(params, record)
+
+  def get_noised_result(self, sample_state, global_state):
+    """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
+    noised_results, inner_query_state = self._inner_query.get_noised_result(
+        sample_state, global_state.inner_query_state)
+    restart_flag, indicator_state = self._restart_indicator.next(
+        global_state.indicator_state)
+    # NOTE(review): a Python `if` on `restart_flag` presumably relies on eager
+    # execution or AutoGraph when the indicator returns a tensor - confirm.
+    if restart_flag:
+      inner_query_state = self._inner_query.reset_state(noised_results,
+                                                        inner_query_state)
+    return noised_results, self._GlobalState(inner_query_state, indicator_state)
+
+  def derive_metrics(self, global_state):
+    """Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
+    return self._inner_query.derive_metrics(global_state.inner_query_state)
diff --git a/tensorflow_privacy/privacy/dp_query/restart_query_test.py b/tensorflow_privacy/privacy/dp_query/restart_query_test.py
new file mode 100644
index 0000000..ef57a2b
--- /dev/null
+++ b/tensorflow_privacy/privacy/dp_query/restart_query_test.py
@@ -0,0 +1,126 @@
+# Copyright 2021, Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for `restart_query`.""" +from absl.testing import parameterized + +import tensorflow as tf +from tensorflow_privacy.privacy.dp_query import restart_query +from tensorflow_privacy.privacy.dp_query import tree_aggregation_query + + +class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters(('zero', 0), ('negative', -1)) + def test_round_raise(self, frequency): + with self.assertRaisesRegex( + ValueError, 'Restart frequency should be equal or larger than 1'): + restart_query.PeriodicRoundRestartIndicator(frequency) + + @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5)) + def test_round_indicator(self, frequency): + total_steps = 20 + indicator = restart_query.PeriodicRoundRestartIndicator(frequency) + state = indicator.initialize() + for i in range(total_steps): + flag, state = indicator.next(state) + if i % frequency == frequency - 1: + self.assertTrue(flag) + else: + self.assertFalse(flag) + + +def _get_l2_clip_fn(): + + def l2_clip_fn(record_as_list, clip_value): + clipped_record, _ = tf.clip_by_global_norm(record_as_list, clip_value) + return clipped_record + + return l2_clip_fn + + +class RestartQueryTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('s0t1f1', 0., 1., 1), + ('s0t1f2', 0., 1., 2), + ('s0t1f5', 0., 1., 5), + ('s1t1f5', 1., 1., 5), + ('s1t2f2', 1., 2., 2), + ('s1t5f6', 1., 5., 6), + ) + def test_sum_scalar_tree_aggregation_reset(self, scalar_value, + tree_node_value, frequency): + total_steps = 20 + indicator = restart_query.PeriodicRoundRestartIndicator(frequency) + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=scalar_value + 1., # no clip + noise_generator=lambda: tree_node_value, + record_specs=tf.TensorSpec([]), + use_efficient=False) + query = restart_query.RestartQuery(query, indicator) + global_state = query.initial_global_state() + params = 
query.derive_sample_params(global_state) + for i in range(total_steps): + sample_state = query.initial_sample_state(scalar_value) + sample_state = query.accumulate_record(params, sample_state, scalar_value) + query_result, global_state = query.get_noised_result( + sample_state, global_state) + # Expected value is the combination of cumsum of signal; sum of trees + # that have been reset; current tree sum. The tree aggregation value can + # be inferred from the binary representation of the current step. + expected = ( + scalar_value * (i + 1) + + i // frequency * tree_node_value * bin(frequency)[2:].count('1') + + tree_node_value * bin(i % frequency + 1)[2:].count('1')) + self.assertEqual(query_result, expected) + + @parameterized.named_parameters( + ('s0t1f1', 0., 1., 1), + ('s0t1f2', 0., 1., 2), + ('s0t1f5', 0., 1., 5), + ('s1t1f5', 1., 1., 5), + ('s1t2f2', 1., 2., 2), + ('s1t5f6', 1., 5., 6), + ) + def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value, + frequency): + total_steps = 20 + indicator = restart_query.PeriodicRoundRestartIndicator(frequency) + query = tree_aggregation_query.TreeResidualSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=scalar_value + 1., # no clip + noise_generator=lambda: tree_node_value, + record_specs=tf.TensorSpec([]), + use_efficient=False) + query = restart_query.RestartQuery(query, indicator) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + for i in range(total_steps): + sample_state = query.initial_sample_state(scalar_value) + sample_state = query.accumulate_record(params, sample_state, scalar_value) + query_result, global_state = query.get_noised_result( + sample_state, global_state) + # Expected value is the signal of the current round plus the residual of + # two continous tree aggregation values. The tree aggregation value can + # be inferred from the binary representation of the current step. 
+ expected = scalar_value + tree_node_value * ( + bin(i % frequency + 1)[2:].count('1') - + bin(i % frequency)[2:].count('1')) + print(i, query_result, expected) + self.assertEqual(query_result, expected) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py index e4cc35f..0842975 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py @@ -171,78 +171,6 @@ class StatelessValueGenerator(ValueGenerator): return self.value_fn(), state -# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be -# in the same module. - - -class RestartIndicator(metaclass=abc.ABCMeta): - """Base class establishing interface for restarting the tree state. - - A `RestartIndicator` maintains a state, and each time `next` is called, a bool - value is generated to indicate whether to restart, and the indicator state is - advanced. - """ - - @abc.abstractmethod - def initialize(self): - """Makes an initialized state for `RestartIndicator`. - - Returns: - An initial state. - """ - raise NotImplementedError - - @abc.abstractmethod - def next(self, state): - """Gets next bool indicator and advances the `RestartIndicator` state. - - Args: - state: The current state. - - Returns: - A pair (value, new_state) where value is bool indicator and new_state - is the advanced state. - """ - raise NotImplementedError - - -class PeriodicRoundRestartIndicator(RestartIndicator): - """Indicator for resetting the tree state after every a few number of queries. - - The indicator will maintain an internal counter as state. - """ - - def __init__(self, frequency: int): - """Construct the `PeriodicRoundRestartIndicator`. - - Args: - frequency: The `next` function will return `True` every `frequency` number - of `next` calls. 
- """ - if frequency < 1: - raise ValueError('Restart frequency should be equal or larger than 1 ' - f'got {frequency}') - self.frequency = tf.constant(frequency, tf.int32) - - def initialize(self): - """Returns initialized state of 0 for `PeriodicRoundRestartIndicator`.""" - return tf.constant(0, tf.int32) - - def next(self, state): - """Gets next bool indicator and advances the state. - - Args: - state: The current state. - - Returns: - A pair (value, new_state) where value is the bool indicator and new_state - of `state+1`. - """ - state = state + tf.constant(1, tf.int32) - flag = state % self.frequency == 0 - return flag, state - - @attr.s(eq=False, frozen=True, slots=True) class TreeState(object): """Class defining state of the tree. diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 3120eea..4907585 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -72,8 +72,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): node. Noise stdandard deviation is specified outside the `dp_query` by the user when defining `noise_fn` and should have order O(clip_norm*log(T)/eps) to guarantee eps-DP. - restart_indicator: `tree_aggregation.RestartIndicator` to generate the - boolean indicator for resetting the tree state. """ @attr.s(frozen=True) @@ -85,21 +83,17 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): each level state. clip_value: The clipping value to be passed to clip_fn. samples_cumulative_sum: Noiseless cumulative sum of samples over time. - restarter_state: Current state of the restarter to indicate whether - the tree state will be reset. 
""" tree_state = attr.ib() clip_value = attr.ib() samples_cumulative_sum = attr.ib() - restarter_state = attr.ib() def __init__(self, record_specs, noise_generator, clip_fn, clip_value, - use_efficient=True, - restart_indicator=None): + use_efficient=True): """Initializes the `TreeCumulativeSumQuery`. Consider using `build_l2_gaussian_query` for the construction of a @@ -117,8 +111,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". - restart_indicator: `tree_aggregation.RestartIndicator` to generate the - boolean indicator for resetting the tree state. """ self._clip_fn = clip_fn self._clip_value = clip_value @@ -128,21 +120,16 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): noise_generator) else: self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) - self._restart_indicator = restart_indicator def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() initial_samples_cumulative_sum = tf.nest.map_structure( lambda spec: tf.zeros(spec.shape), self._record_specs) - restarter_state = () - if self._restart_indicator is not None: - restarter_state = self._restart_indicator.initialize() return TreeCumulativeSumQuery.GlobalState( tree_state=initial_tree_state, clip_value=tf.constant(self._clip_value, tf.float32), - samples_cumulative_sum=initial_samples_cumulative_sum, - restarter_state=restarter_state) + samples_cumulative_sum=initial_samples_cumulative_sum) def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" @@ -185,28 +172,41 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): global_state.tree_state) noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum, 
cumulative_sum_noise) - restarter_state = global_state.restarter_state - if self._restart_indicator is not None: - restart_flag, restarter_state = self._restart_indicator.next( - restarter_state) - if restart_flag: - new_cumulative_sum = noised_cumulative_sum - new_tree_state = self._tree_aggregator.reset_state(new_tree_state) new_global_state = attr.evolve( global_state, samples_cumulative_sum=new_cumulative_sum, - tree_state=new_tree_state, - restarter_state=restarter_state) + tree_state=new_tree_state) return noised_cumulative_sum, new_global_state + def reset_state(self, noised_results, global_state): + """Returns state after resetting the tree. + + This function will be used in `restart_query.RestartQuery` after calling + `get_noised_result` when the restarting condition is met. + + Args: + noised_results: Noised cumulative sum returned by `get_noised_result`. + global_state: Updated global state returned by `get_noised_result`, which + has current sample's cumulative sum and tree state for the next + cumulative sum. + + Returns: + New global state with current noised cumulative sum and restarted tree + state for the next cumulative sum. + """ + new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state) + return attr.evolve( + global_state, + samples_cumulative_sum=noised_results, + tree_state=new_tree_state) + @classmethod def build_l2_gaussian_query(cls, clip_norm, noise_multiplier, record_specs, noise_seed=None, - use_efficient=True, - restart_indicator=None): + use_efficient=True): """Returns a query instance with L2 norm clipping and Gaussian noise. Args: @@ -221,8 +221,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". - restart_indicator: `tree_aggregation.RestartIndicator` to generate the - boolean indicator for resetting the tree state. 
""" if clip_norm <= 0: raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') @@ -245,8 +243,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): clip_value=clip_norm, record_specs=record_specs, noise_generator=gaussian_noise_generator, - use_efficient=use_efficient, - restart_indicator=restart_indicator) + use_efficient=use_efficient) class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): @@ -300,8 +297,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): node. Noise stdandard deviation is specified outside the `dp_query` by the user when defining `noise_fn` and should have order O(clip_norm*log(T)/eps) to guarantee eps-DP. - restart_indicator: `tree_aggregation.RestartIndicator` to generate the - boolean indicator for resetting the tree state. """ @attr.s(frozen=True) @@ -314,21 +309,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): clip_value: The clipping value to be passed to clip_fn. previous_tree_noise: Cumulative noise by tree aggregation from the previous time the query is called on a sample. - restarter_state: Current state of the restarter to indicate whether - the tree state will be reset. """ tree_state = attr.ib() clip_value = attr.ib() previous_tree_noise = attr.ib() - restarter_state = attr.ib() def __init__(self, record_specs, noise_generator, clip_fn, clip_value, - use_efficient=True, - restart_indicator=None): + use_efficient=True): """Initializes the `TreeCumulativeSumQuery`. Consider using `build_l2_gaussian_query` for the construction of a @@ -346,8 +337,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". - restart_indicator: `tree_aggregation.RestartIndicator` to generate the - boolean indicator for resetting the tree state. 
""" self._clip_fn = clip_fn self._clip_value = clip_value @@ -357,7 +346,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): noise_generator) else: self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) - self._restart_indicator = restart_indicator def _zero_initial_noise(self): return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), @@ -366,14 +354,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() - restarter_state = () - if self._restart_indicator is not None: - restarter_state = self._restart_indicator.initialize() return TreeResidualSumQuery.GlobalState( tree_state=initial_tree_state, clip_value=tf.constant(self._clip_value, tf.float32), - previous_tree_noise=self._zero_initial_noise(), - restarter_state=restarter_state) + previous_tree_noise=self._zero_initial_noise()) def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" @@ -412,28 +396,39 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c, sample_state, tree_noise, global_state.previous_tree_noise) - restarter_state = global_state.restarter_state - if self._restart_indicator is not None: - restart_flag, restarter_state = self._restart_indicator.next( - restarter_state) - if restart_flag: - tree_noise = self._zero_initial_noise() - new_tree_state = self._tree_aggregator.reset_state(new_tree_state) new_global_state = attr.evolve( - global_state, - previous_tree_noise=tree_noise, - tree_state=new_tree_state, - restarter_state=restarter_state) + global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state) return noised_sample, new_global_state + def reset_state(self, noised_results, global_state): + """Returns state after resetting the tree. 
+ + This function will be used in `restart_query.RestartQuery` after calling + `get_noised_result` when the restarting condition is met. + + Args: + noised_results: Noised sample returned by `get_noised_result`; unused. + global_state: Updated global state returned by `get_noised_result`, which + records noise for the conceptual cumulative sum of the current leaf + node, and tree state for the next conceptual cumulative sum. + + Returns: + New global state with zero noise and restarted tree state. + """ + del noised_results + new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state) + return attr.evolve( + global_state, + previous_tree_noise=self._zero_initial_noise(), + tree_state=new_tree_state) + @classmethod def build_l2_gaussian_query(cls, clip_norm, noise_multiplier, record_specs, noise_seed=None, - use_efficient=True, - restart_indicator=None): + use_efficient=True): """Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. Args: @@ -448,8 +443,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". - restart_indicator: `tree_aggregation.RestartIndicator` to generate the - boolean indicator for resetting the tree state. 
""" if clip_norm <= 0: raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') @@ -472,8 +465,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): clip_value=clip_norm, record_specs=record_specs, noise_generator=gaussian_noise_generator, - use_efficient=use_efficient, - restart_indicator=restart_indicator) + use_efficient=use_efficient) # TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index f88ed90..65ab076 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -303,15 +303,12 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): def test_sum_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value, frequency): total_steps = 20 - indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency) query = tree_aggregation_query.TreeCumulativeSumQuery( clip_fn=_get_l2_clip_fn(), clip_value=scalar_value + 1., # no clip noise_generator=lambda: tree_node_value, record_specs=tf.TensorSpec([]), - use_efficient=False, - restart_indicator=indicator, - ) + use_efficient=False) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) for i in range(total_steps): @@ -319,6 +316,8 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): sample_state = query.accumulate_record(params, sample_state, scalar_value) query_result, global_state = query.get_noised_result( sample_state, global_state) + if i % frequency == frequency - 1: + global_state = query.reset_state(query_result, global_state) # Expected value is the combination of cumsum of signal; sum of trees # that have been reset; current tree sum. 
The tree aggregation value can # be inferred from the binary representation of the current step. @@ -446,15 +445,12 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase): def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value, frequency): total_steps = 20 - indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency) query = tree_aggregation_query.TreeResidualSumQuery( clip_fn=_get_l2_clip_fn(), clip_value=scalar_value + 1., # no clip noise_generator=lambda: tree_node_value, record_specs=tf.TensorSpec([]), - use_efficient=False, - restart_indicator=indicator, - ) + use_efficient=False) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) for i in range(total_steps): @@ -462,6 +458,8 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase): sample_state = query.accumulate_record(params, sample_state, scalar_value) query_result, global_state = query.get_noised_result( sample_state, global_state) + if i % frequency == frequency - 1: + global_state = query.reset_state(query_result, global_state) # Expected value is the signal of the current round plus the residual of # two continous tree aggregation values. The tree aggregation value can # be inferred from the binary representation of the current step. 
diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py index 47be880..2f6ad82 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py @@ -396,26 +396,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase): self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds) -class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters(('zero', 0), ('negative', -1)) - def test_round_raise(self, frequency): - with self.assertRaisesRegex( - ValueError, 'Restart frequency should be equal or larger than 1'): - tree_aggregation.PeriodicRoundRestartIndicator(frequency) - - @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5)) - def test_round_indicator(self, frequency): - total_steps = 20 - indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency) - state = indicator.initialize() - for i in range(total_steps): - flag, state = indicator.next(state) - if i % frequency == frequency - 1: - self.assertTrue(flag) - else: - self.assertFalse(flag) - - if __name__ == '__main__': tf.test.main()