Define RestartQuery for easy composition to restart tree in tree aggregation queries.

PiperOrigin-RevId: 394106175
This commit is contained in:
Zheng Xu 2021-08-31 16:05:57 -07:00 committed by A. Unique TensorFlower
parent 789a05df63
commit 6ac4bc8d01
7 changed files with 336 additions and 163 deletions

View file

@ -62,7 +62,9 @@ else:
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery

View file

@ -0,0 +1,148 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for restarting the states of another query.
This query is used to compose with a DPQuery that has `reset_state` function.
"""
import abc
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
class RestartIndicator(metaclass=abc.ABCMeta):
  """Abstract interface for deciding when the tree state should be restarted.

  The indicator state is created by `initialize` and is passed to, and returned
  from, every call to `next`. Each `next` call yields a bool restart decision
  together with the advanced state.
  """

  @abc.abstractmethod
  def initialize(self):
    """Creates the starting state of the indicator.

    Returns:
      The initial indicator state.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def next(self, state):
    """Produces the next restart decision and advances the indicator state.

    Args:
      state: The indicator state before this call.

    Returns:
      A `(value, new_state)` pair, where `value` is the bool restart indicator
      and `new_state` is the advanced state.
    """
    raise NotImplementedError()
class PeriodicRoundRestartIndicator(RestartIndicator):
  """Restart indicator that fires once every fixed number of `next` calls.

  The state is an int32 counter that `next` increments by one on every call.
  """

  def __init__(self, frequency: int):
    """Builds the `PeriodicRoundRestartIndicator`.

    Args:
      frequency: Restart period; `next` returns `True` on every
        `frequency`-th call.

    Raises:
      ValueError: If `frequency` is smaller than 1.
    """
    if frequency < 1:
      message = ('Restart frequency should be equal or larger than 1 '
                 f'got {frequency}')
      raise ValueError(message)
    self.frequency = tf.constant(frequency, tf.int32)

  def initialize(self):
    """Returns the initial counter state, an int32 constant 0."""
    return tf.constant(0, tf.int32)

  def next(self, state):
    """Advances the counter and reports whether a restart is due.

    Args:
      state: The current counter state.

    Returns:
      A `(value, new_state)` pair, where `new_state` is `state + 1` and `value`
      indicates whether `new_state` is a multiple of the restart frequency.
    """
    advanced = state + tf.constant(1, tf.int32)
    restart = advanced % self.frequency == 0
    return restart, advanced
class RestartQuery(dp_query.SumAggregationDPQuery):
  """`DPQuery` wrapper that periodically resets the state of an inner query.

  Every `DPQuery` method is forwarded to the inner query. After each
  `get_noised_result`, the restart indicator is advanced, and when it fires the
  inner query's `reset_state` is applied to the inner global state.
  """

  # pylint: disable=invalid-name
  _GlobalState = collections.namedtuple(
      '_GlobalState', ('inner_query_state', 'indicator_state'))

  def __init__(self, inner_query: dp_query.SumAggregationDPQuery,
               restart_indicator: RestartIndicator):
    """Initializes `RestartQuery`.

    Args:
      inner_query: A `SumAggregationDPQuery` that has a `reset_state`
        attribute.
      restart_indicator: A `RestartIndicator` that generates the boolean
        indicator for resetting the state.

    Raises:
      ValueError: If `inner_query` has no `reset_state` attribute.
    """
    if not hasattr(inner_query, 'reset_state'):
      raise ValueError(f'{type(inner_query)} must define `reset_state` to be '
                       'composed with `RestartQuery`.')
    self._inner_query = inner_query
    self._restart_indicator = restart_indicator

  def initial_global_state(self):
    """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
    inner_state = self._inner_query.initial_global_state()
    indicator_state = self._restart_indicator.initialize()
    return self._GlobalState(
        inner_query_state=inner_state, indicator_state=indicator_state)

  def derive_sample_params(self, global_state):
    """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
    inner_state = global_state.inner_query_state
    return self._inner_query.derive_sample_params(inner_state)

  def initial_sample_state(self, template):
    """Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
    return self._inner_query.initial_sample_state(template)

  def preprocess_record(self, params, record):
    """Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
    return self._inner_query.preprocess_record(params, record)

  def get_noised_result(self, sample_state, global_state):
    """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
    inner = self._inner_query
    noised_results, inner_state = inner.get_noised_result(
        sample_state, global_state.inner_query_state)
    # The indicator advances on every call, whether or not a restart happens.
    restart_flag, indicator_state = self._restart_indicator.next(
        global_state.indicator_state)
    if restart_flag:
      inner_state = inner.reset_state(noised_results, inner_state)
    new_global_state = self._GlobalState(
        inner_query_state=inner_state, indicator_state=indicator_state)
    return noised_results, new_global_state

  def derive_metrics(self, global_state):
    """Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
    inner_state = global_state.inner_query_state
    return self._inner_query.derive_metrics(inner_state)

View file

@ -0,0 +1,126 @@
# Copyright 2021, Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `restart_query`."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `restart_query.PeriodicRoundRestartIndicator`."""

  @parameterized.named_parameters(('zero', 0), ('negative', -1))
  def test_round_raise(self, frequency):
    """A frequency below 1 is rejected at construction time."""
    expected_message = 'Restart frequency should be equal or larger than 1'
    with self.assertRaisesRegex(ValueError, expected_message):
      restart_query.PeriodicRoundRestartIndicator(frequency)

  @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
  def test_round_indicator(self, frequency):
    """The flag is true exactly on every `frequency`-th call to `next`."""
    total_steps = 20
    indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
    state = indicator.initialize()
    for calls_made in range(1, total_steps + 1):
      flag, state = indicator.next(state)
      due = calls_made % frequency == 0
      check = self.assertTrue if due else self.assertFalse
      check(flag)
def _get_l2_clip_fn():
  """Returns a clip function built on `tf.clip_by_global_norm`."""

  def l2_clip_fn(record_as_list, clip_value):
    # `tf.clip_by_global_norm` returns a pair; only the clipped record (the
    # first element) is needed here.
    return tf.clip_by_global_norm(record_as_list, clip_value)[0]

  return l2_clip_fn
class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
  """Tests `RestartQuery` composed with the tree aggregation queries."""

  @parameterized.named_parameters(
      ('s0t1f1', 0., 1., 1),
      ('s0t1f2', 0., 1., 2),
      ('s0t1f5', 0., 1., 5),
      ('s1t1f5', 1., 1., 5),
      ('s1t2f2', 1., 2., 2),
      ('s1t5f6', 1., 5., 6),
  )
  def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
                                             tree_node_value, frequency):
    """Checks `TreeCumulativeSumQuery` results when restarted periodically.

    Args:
      scalar_value: The scalar record fed to the query at every step.
      tree_node_value: The constant "noise" returned for every tree node.
      frequency: The restart period of the indicator.
    """
    total_steps = 20
    indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=scalar_value + 1.,  # no clip
        noise_generator=lambda: tree_node_value,
        record_specs=tf.TensorSpec([]),
        use_efficient=False)
    query = restart_query.RestartQuery(query, indicator)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    for i in range(total_steps):
      sample_state = query.initial_sample_state(scalar_value)
      sample_state = query.accumulate_record(params, sample_state, scalar_value)
      query_result, global_state = query.get_noised_result(
          sample_state, global_state)
      # Expected value is the combination of cumsum of signal; sum of trees
      # that have been reset; current tree sum. The tree aggregation value can
      # be inferred from the binary representation of the current step.
      expected = (
          scalar_value * (i + 1) +
          i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
          tree_node_value * bin(i % frequency + 1)[2:].count('1'))
      self.assertEqual(query_result, expected)

  @parameterized.named_parameters(
      ('s0t1f1', 0., 1., 1),
      ('s0t1f2', 0., 1., 2),
      ('s0t1f5', 0., 1., 5),
      ('s1t1f5', 1., 1., 5),
      ('s1t2f2', 1., 2., 2),
      ('s1t5f6', 1., 5., 6),
  )
  def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
                                         frequency):
    """Checks `TreeResidualSumQuery` results when restarted periodically.

    Args:
      scalar_value: The scalar record fed to the query at every step.
      tree_node_value: The constant "noise" returned for every tree node.
      frequency: The restart period of the indicator.
    """
    total_steps = 20
    indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
    query = tree_aggregation_query.TreeResidualSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=scalar_value + 1.,  # no clip
        noise_generator=lambda: tree_node_value,
        record_specs=tf.TensorSpec([]),
        use_efficient=False)
    query = restart_query.RestartQuery(query, indicator)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    for i in range(total_steps):
      sample_state = query.initial_sample_state(scalar_value)
      sample_state = query.accumulate_record(params, sample_state, scalar_value)
      query_result, global_state = query.get_noised_result(
          sample_state, global_state)
      # Expected value is the signal of the current round plus the residual of
      # two consecutive tree aggregation values. The tree aggregation value can
      # be inferred from the binary representation of the current step.
      # (The leftover debugging `print` was removed: on failure `assertEqual`
      # already reports both values.)
      expected = scalar_value + tree_node_value * (
          bin(i % frequency + 1)[2:].count('1') -
          bin(i % frequency)[2:].count('1'))
      self.assertEqual(query_result, expected)
if __name__ == '__main__':
tf.test.main()

View file

@ -171,78 +171,6 @@ class StatelessValueGenerator(ValueGenerator):
return self.value_fn(), state return self.value_fn(), state
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
class RestartIndicator(metaclass=abc.ABCMeta):
  """Interface for objects that signal when the tree state should restart.

  A state value is obtained from `initialize` and handed to `next`, which
  returns a bool restart indicator along with the advanced state.
  """

  @abc.abstractmethod
  def initialize(self):
    """Builds the initial state for the `RestartIndicator`.

    Returns:
      An initial state.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def next(self, state):
    """Returns the next bool indicator and the advanced indicator state.

    Args:
      state: The current state.

    Returns:
      A pair `(value, new_state)`, where `value` is the bool indicator and
      `new_state` is the advanced state.
    """
    raise NotImplementedError()
class PeriodicRoundRestartIndicator(RestartIndicator):
  """Indicator that requests a tree-state reset every `frequency` queries.

  An int32 counter is kept as the indicator state.
  """

  def __init__(self, frequency: int):
    """Constructs the `PeriodicRoundRestartIndicator`.

    Args:
      frequency: `next` returns `True` once every `frequency` calls.

    Raises:
      ValueError: If `frequency` is smaller than 1.
    """
    if frequency < 1:
      error_text = ('Restart frequency should be equal or larger than 1 '
                    f'got {frequency}')
      raise ValueError(error_text)
    self.frequency = tf.constant(frequency, tf.int32)

  def initialize(self):
    """Returns the initial state, an int32 constant 0."""
    return tf.constant(0, tf.int32)

  def next(self, state):
    """Increments the counter and computes the restart indicator.

    Args:
      state: The current state.

    Returns:
      A pair `(value, new_state)`, where `new_state` is `state + 1` and `value`
      tells whether `new_state` is a multiple of the frequency.
    """
    one = tf.constant(1, tf.int32)
    new_state = state + one
    is_restart_round = new_state % self.frequency == 0
    return is_restart_round, new_state
@attr.s(eq=False, frozen=True, slots=True) @attr.s(eq=False, frozen=True, slots=True)
class TreeState(object): class TreeState(object):
"""Class defining state of the tree. """Class defining state of the tree.

View file

@ -72,8 +72,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
node. Noise standard deviation is specified outside the `dp_query` by the node. Noise standard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP. O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
@attr.s(frozen=True) @attr.s(frozen=True)
@ -85,21 +83,17 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
each level state. each level state.
clip_value: The clipping value to be passed to clip_fn. clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time. samples_cumulative_sum: Noiseless cumulative sum of samples over time.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
""" """
tree_state = attr.ib() tree_state = attr.ib()
clip_value = attr.ib() clip_value = attr.ib()
samples_cumulative_sum = attr.ib() samples_cumulative_sum = attr.ib()
restarter_state = attr.ib()
def __init__(self, def __init__(self,
record_specs, record_specs,
noise_generator, noise_generator,
clip_fn, clip_fn,
clip_value, clip_value,
use_efficient=True, use_efficient=True):
restart_indicator=None):
"""Initializes the `TreeCumulativeSumQuery`. """Initializes the `TreeCumulativeSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a Consider using `build_l2_gaussian_query` for the construction of a
@ -117,8 +111,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees". Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
self._clip_fn = clip_fn self._clip_fn = clip_fn
self._clip_value = clip_value self._clip_value = clip_value
@ -128,21 +120,16 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
noise_generator) noise_generator)
else: else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def initial_global_state(self): def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state() initial_tree_state = self._tree_aggregator.init_state()
initial_samples_cumulative_sum = tf.nest.map_structure( initial_samples_cumulative_sum = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape), self._record_specs) lambda spec: tf.zeros(spec.shape), self._record_specs)
restarter_state = ()
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeCumulativeSumQuery.GlobalState( return TreeCumulativeSumQuery.GlobalState(
tree_state=initial_tree_state, tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32), clip_value=tf.constant(self._clip_value, tf.float32),
samples_cumulative_sum=initial_samples_cumulative_sum, samples_cumulative_sum=initial_samples_cumulative_sum)
restarter_state=restarter_state)
def derive_sample_params(self, global_state): def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -185,28 +172,41 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
global_state.tree_state) global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum, noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise) cumulative_sum_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
new_cumulative_sum = noised_cumulative_sum
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve( new_global_state = attr.evolve(
global_state, global_state,
samples_cumulative_sum=new_cumulative_sum, samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state, tree_state=new_tree_state)
restarter_state=restarter_state)
return noised_cumulative_sum, new_global_state return noised_cumulative_sum, new_global_state
def reset_state(self, noised_results, global_state):
  """Returns the global state with the tree restarted.

  Intended to be called by `restart_query.RestartQuery` right after
  `get_noised_result` whenever the restart condition is met.

  Args:
    noised_results: Noised cumulative sum returned by `get_noised_result`.
    global_state: Updated global state returned by `get_noised_result`,
      holding the current cumulative sum of samples and the tree state for the
      next cumulative sum.

  Returns:
    A new global state whose `samples_cumulative_sum` is `noised_results` and
    whose `tree_state` is the restarted tree state.
  """
  restarted_tree = self._tree_aggregator.reset_state(global_state.tree_state)
  updates = dict(
      samples_cumulative_sum=noised_results, tree_state=restarted_tree)
  return attr.evolve(global_state, **updates)
@classmethod @classmethod
def build_l2_gaussian_query(cls, def build_l2_gaussian_query(cls,
clip_norm, clip_norm,
noise_multiplier, noise_multiplier,
record_specs, record_specs,
noise_seed=None, noise_seed=None,
use_efficient=True, use_efficient=True):
restart_indicator=None):
"""Returns a query instance with L2 norm clipping and Gaussian noise. """Returns a query instance with L2 norm clipping and Gaussian noise.
Args: Args:
@ -221,8 +221,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees". Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
if clip_norm <= 0: if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -245,8 +243,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm, clip_value=clip_norm,
record_specs=record_specs, record_specs=record_specs,
noise_generator=gaussian_noise_generator, noise_generator=gaussian_noise_generator,
use_efficient=use_efficient, use_efficient=use_efficient)
restart_indicator=restart_indicator)
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
@ -300,8 +297,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
node. Noise standard deviation is specified outside the `dp_query` by the node. Noise standard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP. O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
@attr.s(frozen=True) @attr.s(frozen=True)
@ -314,21 +309,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value: The clipping value to be passed to clip_fn. clip_value: The clipping value to be passed to clip_fn.
previous_tree_noise: Cumulative noise by tree aggregation from the previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample. previous time the query is called on a sample.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
""" """
tree_state = attr.ib() tree_state = attr.ib()
clip_value = attr.ib() clip_value = attr.ib()
previous_tree_noise = attr.ib() previous_tree_noise = attr.ib()
restarter_state = attr.ib()
def __init__(self, def __init__(self,
record_specs, record_specs,
noise_generator, noise_generator,
clip_fn, clip_fn,
clip_value, clip_value,
use_efficient=True, use_efficient=True):
restart_indicator=None):
"""Initializes the `TreeResidualSumQuery`. """Initializes the `TreeResidualSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a Consider using `build_l2_gaussian_query` for the construction of a
@ -346,8 +337,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees". Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
self._clip_fn = clip_fn self._clip_fn = clip_fn
self._clip_value = clip_value self._clip_value = clip_value
@ -357,7 +346,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noise_generator) noise_generator)
else: else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def _zero_initial_noise(self): def _zero_initial_noise(self):
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
@ -366,14 +354,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
def initial_global_state(self): def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state() initial_tree_state = self._tree_aggregator.init_state()
restarter_state = ()
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeResidualSumQuery.GlobalState( return TreeResidualSumQuery.GlobalState(
tree_state=initial_tree_state, tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32), clip_value=tf.constant(self._clip_value, tf.float32),
previous_tree_noise=self._zero_initial_noise(), previous_tree_noise=self._zero_initial_noise())
restarter_state=restarter_state)
def derive_sample_params(self, global_state): def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -412,28 +396,39 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c, noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise, sample_state, tree_noise,
global_state.previous_tree_noise) global_state.previous_tree_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
tree_noise = self._zero_initial_noise()
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve( new_global_state = attr.evolve(
global_state, global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
previous_tree_noise=tree_noise,
tree_state=new_tree_state,
restarter_state=restarter_state)
return noised_sample, new_global_state return noised_sample, new_global_state
def reset_state(self, noised_results, global_state):
  """Returns the global state with the tree restarted and noise zeroed.

  Intended to be called by `restart_query.RestartQuery` right after
  `get_noised_result` whenever the restart condition is met.

  Args:
    noised_results: Result returned by `get_noised_result`. Unused here; it is
      accepted so the signature matches what `RestartQuery` calls.
    global_state: Updated global state returned by `get_noised_result`, which
      records the noise for the conceptual cumulative sum of the current leaf
      node and the tree state for the next conceptual cumulative sum.

  Returns:
    A new global state with zero `previous_tree_noise` and a restarted
    `tree_state`.
  """
  del noised_results  # Unused.
  restarted_tree = self._tree_aggregator.reset_state(global_state.tree_state)
  zero_noise = self._zero_initial_noise()
  return attr.evolve(
      global_state, previous_tree_noise=zero_noise, tree_state=restarted_tree)
@classmethod @classmethod
def build_l2_gaussian_query(cls, def build_l2_gaussian_query(cls,
clip_norm, clip_norm,
noise_multiplier, noise_multiplier,
record_specs, record_specs,
noise_seed=None, noise_seed=None,
use_efficient=True, use_efficient=True):
restart_indicator=None):
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. """Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
Args: Args:
@ -448,8 +443,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees". Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
if clip_norm <= 0: if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -472,8 +465,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm, clip_value=clip_norm,
record_specs=record_specs, record_specs=record_specs,
noise_generator=gaussian_noise_generator, noise_generator=gaussian_noise_generator,
use_efficient=use_efficient, use_efficient=use_efficient)
restart_indicator=restart_indicator)
# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next # TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next

View file

@ -303,15 +303,12 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_sum_scalar_tree_aggregation_reset(self, scalar_value, def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
tree_node_value, frequency): tree_node_value, frequency):
total_steps = 20 total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeCumulativeSumQuery( query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(), clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value, noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]), record_specs=tf.TensorSpec([]),
use_efficient=False, use_efficient=False)
restart_indicator=indicator,
)
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
for i in range(total_steps): for i in range(total_steps):
@ -319,6 +316,8 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
sample_state = query.accumulate_record(params, sample_state, scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result( query_result, global_state = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
# Expected value is the combination of cumsum of signal; sum of trees # Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can # that have been reset; current tree sum. The tree aggregation value can
# be inferred from the binary representation of the current step. # be inferred from the binary representation of the current step.
@ -446,15 +445,12 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value, def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
frequency): frequency):
total_steps = 20 total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeResidualSumQuery( query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(), clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value, noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]), record_specs=tf.TensorSpec([]),
use_efficient=False, use_efficient=False)
restart_indicator=indicator,
)
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
for i in range(total_steps): for i in range(total_steps):
@ -462,6 +458,8 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
sample_state = query.accumulate_record(params, sample_state, scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result( query_result, global_state = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
# Expected value is the signal of the current round plus the residual of # Expected value is the signal of the current round plus the residual of
# two consecutive tree aggregation values. The tree aggregation value can # two consecutive tree aggregation values. The tree aggregation value can
# be inferred from the binary representation of the current step. # be inferred from the binary representation of the current step.

View file

@ -396,26 +396,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds) self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `tree_aggregation.PeriodicRoundRestartIndicator`."""

  @parameterized.named_parameters(('zero', 0), ('negative', -1))
  def test_round_raise(self, frequency):
    """Constructing the indicator with a frequency below 1 raises."""
    pattern = 'Restart frequency should be equal or larger than 1'
    with self.assertRaisesRegex(ValueError, pattern):
      tree_aggregation.PeriodicRoundRestartIndicator(frequency)

  @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
  def test_round_indicator(self, frequency):
    """The indicator is true only on the last step of each period."""
    total_steps = 20
    indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
    state = indicator.initialize()
    last_step_of_period = frequency - 1
    for step in range(total_steps):
      flag, state = indicator.next(state)
      if step % frequency != last_step_of_period:
        self.assertFalse(flag)
      else:
        self.assertTrue(flag)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()