Define RestartQuery
for easy composition to restart tree in tree aggregation queries.
PiperOrigin-RevId: 394106175
This commit is contained in:
parent
789a05df63
commit
6ac4bc8d01
7 changed files with 336 additions and 163 deletions
|
@ -62,7 +62,9 @@ else:
|
|||
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
|
||||
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
|
||||
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
|
||||
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery
|
||||
|
|
148
tensorflow_privacy/privacy/dp_query/restart_query.py
Normal file
148
tensorflow_privacy/privacy/dp_query/restart_query.py
Normal file
|
@ -0,0 +1,148 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implements DPQuery interface for restarting the states of another query.
|
||||
|
||||
This query is used to compose with a DPQuery that has `reset_state` function.
|
||||
"""
|
||||
import abc
|
||||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
|
||||
class RestartIndicator(metaclass=abc.ABCMeta):
|
||||
"""Base class establishing interface for restarting the tree state.
|
||||
|
||||
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
|
||||
value is generated to indicate whether to restart, and the indicator state is
|
||||
advanced.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def initialize(self):
|
||||
"""Makes an initialized state for `RestartIndicator`.
|
||||
|
||||
Returns:
|
||||
An initial state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def next(self, state):
|
||||
"""Gets next bool indicator and advances the `RestartIndicator` state.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is bool indicator and new_state
|
||||
is the advanced state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PeriodicRoundRestartIndicator(RestartIndicator):
|
||||
"""Indicator for resetting the tree state after every a few number of queries.
|
||||
|
||||
The indicator will maintain an internal counter as state.
|
||||
"""
|
||||
|
||||
def __init__(self, frequency: int):
|
||||
"""Construct the `PeriodicRoundRestartIndicator`.
|
||||
|
||||
Args:
|
||||
frequency: The `next` function will return `True` every `frequency` number
|
||||
of `next` calls.
|
||||
"""
|
||||
if frequency < 1:
|
||||
raise ValueError('Restart frequency should be equal or larger than 1 '
|
||||
f'got {frequency}')
|
||||
self.frequency = tf.constant(frequency, tf.int32)
|
||||
|
||||
def initialize(self):
|
||||
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
|
||||
return tf.constant(0, tf.int32)
|
||||
|
||||
def next(self, state):
|
||||
"""Gets next bool indicator and advances the state.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is the bool indicator and new_state
|
||||
of `state+1`.
|
||||
"""
|
||||
state = state + tf.constant(1, tf.int32)
|
||||
flag = state % self.frequency == 0
|
||||
return flag, state
|
||||
|
||||
|
||||
class RestartQuery(dp_query.SumAggregationDPQuery):
|
||||
"""`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple(
|
||||
'_GlobalState', ['inner_query_state', 'indicator_state'])
|
||||
|
||||
def __init__(self, inner_query: dp_query.SumAggregationDPQuery,
|
||||
restart_indicator: RestartIndicator):
|
||||
"""Initializes `RestartQuery`.
|
||||
|
||||
Args:
|
||||
inner_query: A `SumAggregationDPQuery` has `reset_state` attribute.
|
||||
restart_indicator: A `RestartIndicator` to generate the boolean indicator
|
||||
for resetting the state.
|
||||
"""
|
||||
if not hasattr(inner_query, 'reset_state'):
|
||||
raise ValueError(f'{type(inner_query)} must define `reset_state` to be '
|
||||
'composed with `RestartQuery`.')
|
||||
self._inner_query = inner_query
|
||||
self._restart_indicator = restart_indicator
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return self._GlobalState(
|
||||
inner_query_state=self._inner_query.initial_global_state(),
|
||||
indicator_state=self._restart_indicator.initialize())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return self._inner_query.derive_sample_params(
|
||||
global_state.inner_query_state)
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
return self._inner_query.initial_sample_state(template)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
return self._inner_query.preprocess_record(params, record)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
noised_results, inner_query_state = self._inner_query.get_noised_result(
|
||||
sample_state, global_state.inner_query_state)
|
||||
restart_flag, indicator_state = self._restart_indicator.next(
|
||||
global_state.indicator_state)
|
||||
if restart_flag:
|
||||
inner_query_state = self._inner_query.reset_state(noised_results,
|
||||
inner_query_state)
|
||||
return noised_results, self._GlobalState(inner_query_state, indicator_state)
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
return self._inner_query.derive_metrics(global_state.inner_query_state)
|
126
tensorflow_privacy/privacy/dp_query/restart_query_test.py
Normal file
126
tensorflow_privacy/privacy/dp_query/restart_query_test.py
Normal file
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2021, Google LLC.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for `restart_query`."""
|
||||
from absl.testing import parameterized
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||
|
||||
|
||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
||||
def test_round_raise(self, frequency):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Restart frequency should be equal or larger than 1'):
|
||||
restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
|
||||
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
||||
def test_round_indicator(self, frequency):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
state = indicator.initialize()
|
||||
for i in range(total_steps):
|
||||
flag, state = indicator.next(state)
|
||||
if i % frequency == frequency - 1:
|
||||
self.assertTrue(flag)
|
||||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
|
||||
def _get_l2_clip_fn():
|
||||
|
||||
def l2_clip_fn(record_as_list, clip_value):
|
||||
clipped_record, _ = tf.clip_by_global_norm(record_as_list, clip_value)
|
||||
return clipped_record
|
||||
|
||||
return l2_clip_fn
|
||||
|
||||
|
||||
class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('s0t1f1', 0., 1., 1),
|
||||
('s0t1f2', 0., 1., 2),
|
||||
('s0t1f5', 0., 1., 5),
|
||||
('s1t1f5', 1., 1., 5),
|
||||
('s1t2f2', 1., 2., 2),
|
||||
('s1t5f6', 1., 5., 6),
|
||||
)
|
||||
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
|
||||
tree_node_value, frequency):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
noise_generator=lambda: tree_node_value,
|
||||
record_specs=tf.TensorSpec([]),
|
||||
use_efficient=False)
|
||||
query = restart_query.RestartQuery(query, indicator)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# Expected value is the combination of cumsum of signal; sum of trees
|
||||
# that have been reset; current tree sum. The tree aggregation value can
|
||||
# be inferred from the binary representation of the current step.
|
||||
expected = (
|
||||
scalar_value * (i + 1) +
|
||||
i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
|
||||
tree_node_value * bin(i % frequency + 1)[2:].count('1'))
|
||||
self.assertEqual(query_result, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('s0t1f1', 0., 1., 1),
|
||||
('s0t1f2', 0., 1., 2),
|
||||
('s0t1f5', 0., 1., 5),
|
||||
('s1t1f5', 1., 1., 5),
|
||||
('s1t2f2', 1., 2., 2),
|
||||
('s1t5f6', 1., 5., 6),
|
||||
)
|
||||
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
|
||||
frequency):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
query = tree_aggregation_query.TreeResidualSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
noise_generator=lambda: tree_node_value,
|
||||
record_specs=tf.TensorSpec([]),
|
||||
use_efficient=False)
|
||||
query = restart_query.RestartQuery(query, indicator)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# Expected value is the signal of the current round plus the residual of
|
||||
# two continous tree aggregation values. The tree aggregation value can
|
||||
# be inferred from the binary representation of the current step.
|
||||
expected = scalar_value + tree_node_value * (
|
||||
bin(i % frequency + 1)[2:].count('1') -
|
||||
bin(i % frequency)[2:].count('1'))
|
||||
print(i, query_result, expected)
|
||||
self.assertEqual(query_result, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -171,78 +171,6 @@ class StatelessValueGenerator(ValueGenerator):
|
|||
return self.value_fn(), state
|
||||
|
||||
|
||||
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
|
||||
# in the same module.
|
||||
|
||||
|
||||
class RestartIndicator(metaclass=abc.ABCMeta):
|
||||
"""Base class establishing interface for restarting the tree state.
|
||||
|
||||
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
|
||||
value is generated to indicate whether to restart, and the indicator state is
|
||||
advanced.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def initialize(self):
|
||||
"""Makes an initialized state for `RestartIndicator`.
|
||||
|
||||
Returns:
|
||||
An initial state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def next(self, state):
|
||||
"""Gets next bool indicator and advances the `RestartIndicator` state.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is bool indicator and new_state
|
||||
is the advanced state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PeriodicRoundRestartIndicator(RestartIndicator):
|
||||
"""Indicator for resetting the tree state after every a few number of queries.
|
||||
|
||||
The indicator will maintain an internal counter as state.
|
||||
"""
|
||||
|
||||
def __init__(self, frequency: int):
|
||||
"""Construct the `PeriodicRoundRestartIndicator`.
|
||||
|
||||
Args:
|
||||
frequency: The `next` function will return `True` every `frequency` number
|
||||
of `next` calls.
|
||||
"""
|
||||
if frequency < 1:
|
||||
raise ValueError('Restart frequency should be equal or larger than 1 '
|
||||
f'got {frequency}')
|
||||
self.frequency = tf.constant(frequency, tf.int32)
|
||||
|
||||
def initialize(self):
|
||||
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
|
||||
return tf.constant(0, tf.int32)
|
||||
|
||||
def next(self, state):
|
||||
"""Gets next bool indicator and advances the state.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is the bool indicator and new_state
|
||||
of `state+1`.
|
||||
"""
|
||||
state = state + tf.constant(1, tf.int32)
|
||||
flag = state % self.frequency == 0
|
||||
return flag, state
|
||||
|
||||
|
||||
@attr.s(eq=False, frozen=True, slots=True)
|
||||
class TreeState(object):
|
||||
"""Class defining state of the tree.
|
||||
|
|
|
@ -72,8 +72,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
node. Noise stdandard deviation is specified outside the `dp_query` by the
|
||||
user when defining `noise_fn` and should have order
|
||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
|
@ -85,21 +83,17 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
each level state.
|
||||
clip_value: The clipping value to be passed to clip_fn.
|
||||
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
|
||||
restarter_state: Current state of the restarter to indicate whether
|
||||
the tree state will be reset.
|
||||
"""
|
||||
tree_state = attr.ib()
|
||||
clip_value = attr.ib()
|
||||
samples_cumulative_sum = attr.ib()
|
||||
restarter_state = attr.ib()
|
||||
|
||||
def __init__(self,
|
||||
record_specs,
|
||||
noise_generator,
|
||||
clip_fn,
|
||||
clip_value,
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
use_efficient=True):
|
||||
"""Initializes the `TreeCumulativeSumQuery`.
|
||||
|
||||
Consider using `build_l2_gaussian_query` for the construction of a
|
||||
|
@ -117,8 +111,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
self._clip_fn = clip_fn
|
||||
self._clip_value = clip_value
|
||||
|
@ -128,21 +120,16 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noise_generator)
|
||||
else:
|
||||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||
self._restart_indicator = restart_indicator
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
initial_tree_state = self._tree_aggregator.init_state()
|
||||
initial_samples_cumulative_sum = tf.nest.map_structure(
|
||||
lambda spec: tf.zeros(spec.shape), self._record_specs)
|
||||
restarter_state = ()
|
||||
if self._restart_indicator is not None:
|
||||
restarter_state = self._restart_indicator.initialize()
|
||||
return TreeCumulativeSumQuery.GlobalState(
|
||||
tree_state=initial_tree_state,
|
||||
clip_value=tf.constant(self._clip_value, tf.float32),
|
||||
samples_cumulative_sum=initial_samples_cumulative_sum,
|
||||
restarter_state=restarter_state)
|
||||
samples_cumulative_sum=initial_samples_cumulative_sum)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
|
@ -185,28 +172,41 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
global_state.tree_state)
|
||||
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
||||
cumulative_sum_noise)
|
||||
restarter_state = global_state.restarter_state
|
||||
if self._restart_indicator is not None:
|
||||
restart_flag, restarter_state = self._restart_indicator.next(
|
||||
restarter_state)
|
||||
if restart_flag:
|
||||
new_cumulative_sum = noised_cumulative_sum
|
||||
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
|
||||
new_global_state = attr.evolve(
|
||||
global_state,
|
||||
samples_cumulative_sum=new_cumulative_sum,
|
||||
tree_state=new_tree_state,
|
||||
restarter_state=restarter_state)
|
||||
tree_state=new_tree_state)
|
||||
return noised_cumulative_sum, new_global_state
|
||||
|
||||
def reset_state(self, noised_results, global_state):
|
||||
"""Returns state after resetting the tree.
|
||||
|
||||
This function will be used in `restart_query.RestartQuery` after calling
|
||||
`get_noised_result` when the restarting condition is met.
|
||||
|
||||
Args:
|
||||
noised_results: Noised cumulative sum returned by `get_noised_result`.
|
||||
global_state: Updated global state returned by `get_noised_result`, which
|
||||
has current sample's cumulative sum and tree state for the next
|
||||
cumulative sum.
|
||||
|
||||
Returns:
|
||||
New global state with current noised cumulative sum and restarted tree
|
||||
state for the next cumulative sum.
|
||||
"""
|
||||
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
||||
return attr.evolve(
|
||||
global_state,
|
||||
samples_cumulative_sum=noised_results,
|
||||
tree_state=new_tree_state)
|
||||
|
||||
@classmethod
|
||||
def build_l2_gaussian_query(cls,
|
||||
clip_norm,
|
||||
noise_multiplier,
|
||||
record_specs,
|
||||
noise_seed=None,
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
use_efficient=True):
|
||||
"""Returns a query instance with L2 norm clipping and Gaussian noise.
|
||||
|
||||
Args:
|
||||
|
@ -221,8 +221,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
if clip_norm <= 0:
|
||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
||||
|
@ -245,8 +243,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
clip_value=clip_norm,
|
||||
record_specs=record_specs,
|
||||
noise_generator=gaussian_noise_generator,
|
||||
use_efficient=use_efficient,
|
||||
restart_indicator=restart_indicator)
|
||||
use_efficient=use_efficient)
|
||||
|
||||
|
||||
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||
|
@ -300,8 +297,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
node. Noise stdandard deviation is specified outside the `dp_query` by the
|
||||
user when defining `noise_fn` and should have order
|
||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
|
@ -314,21 +309,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
clip_value: The clipping value to be passed to clip_fn.
|
||||
previous_tree_noise: Cumulative noise by tree aggregation from the
|
||||
previous time the query is called on a sample.
|
||||
restarter_state: Current state of the restarter to indicate whether
|
||||
the tree state will be reset.
|
||||
"""
|
||||
tree_state = attr.ib()
|
||||
clip_value = attr.ib()
|
||||
previous_tree_noise = attr.ib()
|
||||
restarter_state = attr.ib()
|
||||
|
||||
def __init__(self,
|
||||
record_specs,
|
||||
noise_generator,
|
||||
clip_fn,
|
||||
clip_value,
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
use_efficient=True):
|
||||
"""Initializes the `TreeCumulativeSumQuery`.
|
||||
|
||||
Consider using `build_l2_gaussian_query` for the construction of a
|
||||
|
@ -346,8 +337,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
self._clip_fn = clip_fn
|
||||
self._clip_value = clip_value
|
||||
|
@ -357,7 +346,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noise_generator)
|
||||
else:
|
||||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||
self._restart_indicator = restart_indicator
|
||||
|
||||
def _zero_initial_noise(self):
|
||||
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
|
||||
|
@ -366,14 +354,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
initial_tree_state = self._tree_aggregator.init_state()
|
||||
restarter_state = ()
|
||||
if self._restart_indicator is not None:
|
||||
restarter_state = self._restart_indicator.initialize()
|
||||
return TreeResidualSumQuery.GlobalState(
|
||||
tree_state=initial_tree_state,
|
||||
clip_value=tf.constant(self._clip_value, tf.float32),
|
||||
previous_tree_noise=self._zero_initial_noise(),
|
||||
restarter_state=restarter_state)
|
||||
previous_tree_noise=self._zero_initial_noise())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
|
@ -412,28 +396,39 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
|
||||
sample_state, tree_noise,
|
||||
global_state.previous_tree_noise)
|
||||
restarter_state = global_state.restarter_state
|
||||
if self._restart_indicator is not None:
|
||||
restart_flag, restarter_state = self._restart_indicator.next(
|
||||
restarter_state)
|
||||
if restart_flag:
|
||||
tree_noise = self._zero_initial_noise()
|
||||
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
|
||||
new_global_state = attr.evolve(
|
||||
global_state,
|
||||
previous_tree_noise=tree_noise,
|
||||
tree_state=new_tree_state,
|
||||
restarter_state=restarter_state)
|
||||
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
||||
return noised_sample, new_global_state
|
||||
|
||||
def reset_state(self, noised_results, global_state):
|
||||
"""Returns state after resetting the tree.
|
||||
|
||||
This function will be used in `restart_query.RestartQuery` after calling
|
||||
`get_noised_result` when the restarting condition is met.
|
||||
|
||||
Args:
|
||||
noised_results: Noised cumulative sum returned by `get_noised_result`.
|
||||
global_state: Updated global state returned by `get_noised_result`, which
|
||||
records noise for the conceptual cumulative sum of the current leaf
|
||||
node, and tree state for the next conceptual cumulative sum.
|
||||
|
||||
Returns:
|
||||
New global state with zero noise and restarted tree state.
|
||||
"""
|
||||
del noised_results
|
||||
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
||||
return attr.evolve(
|
||||
global_state,
|
||||
previous_tree_noise=self._zero_initial_noise(),
|
||||
tree_state=new_tree_state)
|
||||
|
||||
@classmethod
|
||||
def build_l2_gaussian_query(cls,
|
||||
clip_norm,
|
||||
noise_multiplier,
|
||||
record_specs,
|
||||
noise_seed=None,
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
use_efficient=True):
|
||||
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
|
||||
|
||||
Args:
|
||||
|
@ -448,8 +443,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
if clip_norm <= 0:
|
||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
||||
|
@ -472,8 +465,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
clip_value=clip_norm,
|
||||
record_specs=record_specs,
|
||||
noise_generator=gaussian_noise_generator,
|
||||
use_efficient=use_efficient,
|
||||
restart_indicator=restart_indicator)
|
||||
use_efficient=use_efficient)
|
||||
|
||||
|
||||
# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next
|
||||
|
|
|
@ -303,15 +303,12 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
|
||||
tree_node_value, frequency):
|
||||
total_steps = 20
|
||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
noise_generator=lambda: tree_node_value,
|
||||
record_specs=tf.TensorSpec([]),
|
||||
use_efficient=False,
|
||||
restart_indicator=indicator,
|
||||
)
|
||||
use_efficient=False)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i in range(total_steps):
|
||||
|
@ -319,6 +316,8 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
if i % frequency == frequency - 1:
|
||||
global_state = query.reset_state(query_result, global_state)
|
||||
# Expected value is the combination of cumsum of signal; sum of trees
|
||||
# that have been reset; current tree sum. The tree aggregation value can
|
||||
# be inferred from the binary representation of the current step.
|
||||
|
@ -446,15 +445,12 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
|
||||
frequency):
|
||||
total_steps = 20
|
||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
query = tree_aggregation_query.TreeResidualSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
noise_generator=lambda: tree_node_value,
|
||||
record_specs=tf.TensorSpec([]),
|
||||
use_efficient=False,
|
||||
restart_indicator=indicator,
|
||||
)
|
||||
use_efficient=False)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i in range(total_steps):
|
||||
|
@ -462,6 +458,8 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
if i % frequency == frequency - 1:
|
||||
global_state = query.reset_state(query_result, global_state)
|
||||
# Expected value is the signal of the current round plus the residual of
|
||||
# two continous tree aggregation values. The tree aggregation value can
|
||||
# be inferred from the binary representation of the current step.
|
||||
|
|
|
@ -396,26 +396,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
|
||||
|
||||
|
||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
||||
def test_round_raise(self, frequency):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Restart frequency should be equal or larger than 1'):
|
||||
tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
|
||||
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
||||
def test_round_indicator(self, frequency):
|
||||
total_steps = 20
|
||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
state = indicator.initialize()
|
||||
for i in range(total_steps):
|
||||
flag, state = indicator.next(state)
|
||||
if i % frequency == frequency - 1:
|
||||
self.assertTrue(flag)
|
||||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue