Define RestartQuery
for easy composition to restart tree in tree aggregation queries.
PiperOrigin-RevId: 394106175
This commit is contained in:
parent
789a05df63
commit
6ac4bc8d01
7 changed files with 336 additions and 163 deletions
|
@ -62,7 +62,9 @@ else:
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
|
||||||
|
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||||
|
from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
|
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
|
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery
|
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery
|
||||||
|
|
148
tensorflow_privacy/privacy/dp_query/restart_query.py
Normal file
148
tensorflow_privacy/privacy/dp_query/restart_query.py
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
# Copyright 2021, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Implements DPQuery interface for restarting the states of another query.
|
||||||
|
|
||||||
|
This query is used to compose with a DPQuery that has `reset_state` function.
|
||||||
|
"""
|
||||||
|
import abc
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
|
||||||
|
|
||||||
|
class RestartIndicator(abc.ABC):
  """Abstract interface that decides when the tree state is restarted.

  The indicator state is created by `initialize` and handed to `next`. Each
  `next` call yields a bool indicator telling the caller whether to restart,
  together with the advanced indicator state.
  """

  @abc.abstractmethod
  def initialize(self):
    """Creates the initial state of the `RestartIndicator`.

    Returns:
      An initial state.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def next(self, state):
    """Produces the next bool indicator and the advanced state.

    Args:
      state: The current state.

    Returns:
      A pair (value, new_state) where value is the bool indicator and
      new_state is the advanced state.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class PeriodicRoundRestartIndicator(RestartIndicator):
  """Indicator for resetting the tree state once every `frequency` queries.

  The indicator state is an int32 scalar counter of how many times `next` has
  been called.
  """

  def __init__(self, frequency: int):
    """Constructs the `PeriodicRoundRestartIndicator`.

    Args:
      frequency: The `next` function will return `True` every `frequency` number
        of `next` calls. Must be at least 1.

    Raises:
      ValueError: If `frequency` is smaller than 1.
    """
    if frequency < 1:
      raise ValueError('Restart frequency should be equal or larger than 1, '
                       f'got {frequency}.')
    self.frequency = tf.constant(frequency, tf.int32)

  def initialize(self):
    """Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
    return tf.constant(0, tf.int32)

  def next(self, state):
    """Gets next bool indicator and advances the state.

    Args:
      state: The current state.

    Returns:
      A pair (value, new_state) where value is the bool indicator and new_state
      is `state+1`.
    """
    state = state + tf.constant(1, tf.int32)
    # Use `tf.equal` instead of `==`: when tensor equality is disabled (TF1
    # behavior) `==` on tensors compares object identity and would yield a
    # Python `False` rather than an element-wise boolean tensor.
    flag = tf.equal(state % self.frequency, 0)
    return flag, state
|
||||||
|
|
||||||
|
|
||||||
|
class RestartQuery(dp_query.SumAggregationDPQuery):
  """`DPQuery` wrapper that restarts the state of an inner query.

  Every `DPQuery` method implemented here delegates to the wrapped query using
  the `inner_query_state` part of the global state. In `get_noised_result` the
  `RestartIndicator` is advanced as well, and when it signals a restart the
  inner query's `reset_state` is applied to the updated inner state.
  """

  # pylint: disable=invalid-name
  _GlobalState = collections.namedtuple(
      '_GlobalState', ('inner_query_state', 'indicator_state'))

  def __init__(self, inner_query: dp_query.SumAggregationDPQuery,
               restart_indicator: RestartIndicator):
    """Initializes `RestartQuery`.

    Args:
      inner_query: A `SumAggregationDPQuery` that has a `reset_state` attribute.
      restart_indicator: A `RestartIndicator` generating the boolean indicator
        that triggers resetting the inner query state.

    Raises:
      ValueError: If `inner_query` has no `reset_state` attribute.
    """
    can_reset = hasattr(inner_query, 'reset_state')
    if not can_reset:
      raise ValueError(
          f'{type(inner_query)} must define `reset_state` to be '
          'composed with `RestartQuery`.')
    self._restart_indicator = restart_indicator
    self._inner_query = inner_query

  def initial_global_state(self):
    """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
    inner_state = self._inner_query.initial_global_state()
    indicator_state = self._restart_indicator.initialize()
    return self._GlobalState(
        inner_query_state=inner_state, indicator_state=indicator_state)

  def derive_sample_params(self, global_state):
    """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
    inner_state = global_state.inner_query_state
    return self._inner_query.derive_sample_params(inner_state)

  def initial_sample_state(self, template):
    """Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
    return self._inner_query.initial_sample_state(template)

  def preprocess_record(self, params, record):
    """Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
    return self._inner_query.preprocess_record(params, record)

  def get_noised_result(self, sample_state, global_state):
    """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
    result, inner_state = self._inner_query.get_noised_result(
        sample_state, global_state.inner_query_state)
    should_restart, next_indicator_state = self._restart_indicator.next(
        global_state.indicator_state)
    if should_restart:
      # The reset sees both the result just released and the updated state.
      inner_state = self._inner_query.reset_state(result, inner_state)
    new_global_state = self._GlobalState(
        inner_query_state=inner_state, indicator_state=next_indicator_state)
    return result, new_global_state

  def derive_metrics(self, global_state):
    """Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
    inner_state = global_state.inner_query_state
    return self._inner_query.derive_metrics(inner_state)
|
126
tensorflow_privacy/privacy/dp_query/restart_query_test.py
Normal file
126
tensorflow_privacy/privacy/dp_query/restart_query_test.py
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
# Copyright 2021, Google LLC.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for `restart_query`."""
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||||
|
|
||||||
|
|
||||||
|
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `restart_query.PeriodicRoundRestartIndicator`."""

  @parameterized.named_parameters(('zero', 0), ('negative', -1))
  def test_round_raise(self, frequency):
    """Frequencies smaller than 1 raise `ValueError` on construction."""
    expected_message = 'Restart frequency should be equal or larger than 1'
    with self.assertRaisesRegex(ValueError, expected_message):
      restart_query.PeriodicRoundRestartIndicator(frequency)

  @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
  def test_round_indicator(self, frequency):
    """The flag is true exactly on every `frequency`-th call to `next`."""
    total_steps = 20
    indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
    state = indicator.initialize()
    # `step` counts calls to `next` starting from 1.
    for step in range(1, total_steps + 1):
      flag, state = indicator.next(state)
      if step % frequency:
        self.assertFalse(flag)
      else:
        self.assertTrue(flag)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_l2_clip_fn():
  """Returns a clip function based on `tf.clip_by_global_norm`."""

  def l2_clip_fn(record_as_list, clip_value):
    # `tf.clip_by_global_norm` returns (clipped_list, global_norm); only the
    # clipped records are needed here.
    return tf.clip_by_global_norm(record_as_list, clip_value)[0]

  return l2_clip_fn
|
||||||
|
|
||||||
|
|
||||||
|
class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
  """Tests `restart_query.RestartQuery` composed with tree aggregation queries.

  Each test uses a constant `noise_generator`, so the tree noise at a given
  step is `tree_node_value` times the number of set bits in the step index.
  """

  @parameterized.named_parameters(
      ('s0t1f1', 0., 1., 1),
      ('s0t1f2', 0., 1., 2),
      ('s0t1f5', 0., 1., 5),
      ('s1t1f5', 1., 1., 5),
      ('s1t2f2', 1., 2., 2),
      ('s1t5f6', 1., 5., 6),
  )
  def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
                                             tree_node_value, frequency):
    """Checks `TreeCumulativeSumQuery` results when restarted periodically."""
    total_steps = 20
    indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=scalar_value + 1.,  # no clip
        noise_generator=lambda: tree_node_value,
        record_specs=tf.TensorSpec([]),
        use_efficient=False)
    query = restart_query.RestartQuery(query, indicator)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    for i in range(total_steps):
      sample_state = query.initial_sample_state(scalar_value)
      sample_state = query.accumulate_record(params, sample_state, scalar_value)
      query_result, global_state = query.get_noised_result(
          sample_state, global_state)
      # Expected value is the combination of cumsum of signal; sum of trees
      # that have been reset; current tree sum. The tree aggregation value can
      # be inferred from the binary representation of the current step.
      expected = (
          scalar_value * (i + 1) +
          i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
          tree_node_value * bin(i % frequency + 1)[2:].count('1'))
      self.assertEqual(query_result, expected)

  @parameterized.named_parameters(
      ('s0t1f1', 0., 1., 1),
      ('s0t1f2', 0., 1., 2),
      ('s0t1f5', 0., 1., 5),
      ('s1t1f5', 1., 1., 5),
      ('s1t2f2', 1., 2., 2),
      ('s1t5f6', 1., 5., 6),
  )
  def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
                                         frequency):
    """Checks `TreeResidualSumQuery` results when restarted periodically."""
    total_steps = 20
    indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
    query = tree_aggregation_query.TreeResidualSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=scalar_value + 1.,  # no clip
        noise_generator=lambda: tree_node_value,
        record_specs=tf.TensorSpec([]),
        use_efficient=False)
    query = restart_query.RestartQuery(query, indicator)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    for i in range(total_steps):
      sample_state = query.initial_sample_state(scalar_value)
      sample_state = query.accumulate_record(params, sample_state, scalar_value)
      query_result, global_state = query.get_noised_result(
          sample_state, global_state)
      # Expected value is the signal of the current round plus the residual of
      # two continuous tree aggregation values. The tree aggregation value can
      # be inferred from the binary representation of the current step.
      expected = scalar_value + tree_node_value * (
          bin(i % frequency + 1)[2:].count('1') -
          bin(i % frequency)[2:].count('1'))
      self.assertEqual(query_result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
|
|
@ -171,78 +171,6 @@ class StatelessValueGenerator(ValueGenerator):
|
||||||
return self.value_fn(), state
|
return self.value_fn(), state
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
|
|
||||||
# in the same module.
|
|
||||||
|
|
||||||
|
|
||||||
class RestartIndicator(metaclass=abc.ABCMeta):
|
|
||||||
"""Base class establishing interface for restarting the tree state.
|
|
||||||
|
|
||||||
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
|
|
||||||
value is generated to indicate whether to restart, and the indicator state is
|
|
||||||
advanced.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def initialize(self):
|
|
||||||
"""Makes an initialized state for `RestartIndicator`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An initial state.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def next(self, state):
|
|
||||||
"""Gets next bool indicator and advances the `RestartIndicator` state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The current state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A pair (value, new_state) where value is bool indicator and new_state
|
|
||||||
is the advanced state.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class PeriodicRoundRestartIndicator(RestartIndicator):
|
|
||||||
"""Indicator for resetting the tree state after every a few number of queries.
|
|
||||||
|
|
||||||
The indicator will maintain an internal counter as state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, frequency: int):
|
|
||||||
"""Construct the `PeriodicRoundRestartIndicator`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frequency: The `next` function will return `True` every `frequency` number
|
|
||||||
of `next` calls.
|
|
||||||
"""
|
|
||||||
if frequency < 1:
|
|
||||||
raise ValueError('Restart frequency should be equal or larger than 1 '
|
|
||||||
f'got {frequency}')
|
|
||||||
self.frequency = tf.constant(frequency, tf.int32)
|
|
||||||
|
|
||||||
def initialize(self):
|
|
||||||
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
|
|
||||||
return tf.constant(0, tf.int32)
|
|
||||||
|
|
||||||
def next(self, state):
|
|
||||||
"""Gets next bool indicator and advances the state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The current state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A pair (value, new_state) where value is the bool indicator and new_state
|
|
||||||
of `state+1`.
|
|
||||||
"""
|
|
||||||
state = state + tf.constant(1, tf.int32)
|
|
||||||
flag = state % self.frequency == 0
|
|
||||||
return flag, state
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(eq=False, frozen=True, slots=True)
|
@attr.s(eq=False, frozen=True, slots=True)
|
||||||
class TreeState(object):
|
class TreeState(object):
|
||||||
"""Class defining state of the tree.
|
"""Class defining state of the tree.
|
||||||
|
|
|
@ -72,8 +72,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
node. Noise standard deviation is specified outside the `dp_query` by the
|
node. Noise standard deviation is specified outside the `dp_query` by the
|
||||||
user when defining `noise_fn` and should have order
|
user when defining `noise_fn` and should have order
|
||||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
|
||||||
boolean indicator for resetting the tree state.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@attr.s(frozen=True)
|
@attr.s(frozen=True)
|
||||||
|
@ -85,21 +83,17 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
each level state.
|
each level state.
|
||||||
clip_value: The clipping value to be passed to clip_fn.
|
clip_value: The clipping value to be passed to clip_fn.
|
||||||
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
|
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
|
||||||
restarter_state: Current state of the restarter to indicate whether
|
|
||||||
the tree state will be reset.
|
|
||||||
"""
|
"""
|
||||||
tree_state = attr.ib()
|
tree_state = attr.ib()
|
||||||
clip_value = attr.ib()
|
clip_value = attr.ib()
|
||||||
samples_cumulative_sum = attr.ib()
|
samples_cumulative_sum = attr.ib()
|
||||||
restarter_state = attr.ib()
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
record_specs,
|
record_specs,
|
||||||
noise_generator,
|
noise_generator,
|
||||||
clip_fn,
|
clip_fn,
|
||||||
clip_value,
|
clip_value,
|
||||||
use_efficient=True,
|
use_efficient=True):
|
||||||
restart_indicator=None):
|
|
||||||
"""Initializes the `TreeCumulativeSumQuery`.
|
"""Initializes the `TreeCumulativeSumQuery`.
|
||||||
|
|
||||||
Consider using `build_l2_gaussian_query` for the construction of a
|
Consider using `build_l2_gaussian_query` for the construction of a
|
||||||
|
@ -117,8 +111,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
use_efficient: Boolean indicating the usage of the efficient tree
|
use_efficient: Boolean indicating the usage of the efficient tree
|
||||||
aggregation algorithm based on the paper "Efficient Use of
|
aggregation algorithm based on the paper "Efficient Use of
|
||||||
Differentially Private Binary Trees".
|
Differentially Private Binary Trees".
|
||||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
|
||||||
boolean indicator for resetting the tree state.
|
|
||||||
"""
|
"""
|
||||||
self._clip_fn = clip_fn
|
self._clip_fn = clip_fn
|
||||||
self._clip_value = clip_value
|
self._clip_value = clip_value
|
||||||
|
@ -128,21 +120,16 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
noise_generator)
|
noise_generator)
|
||||||
else:
|
else:
|
||||||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||||
self._restart_indicator = restart_indicator
|
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
initial_tree_state = self._tree_aggregator.init_state()
|
initial_tree_state = self._tree_aggregator.init_state()
|
||||||
initial_samples_cumulative_sum = tf.nest.map_structure(
|
initial_samples_cumulative_sum = tf.nest.map_structure(
|
||||||
lambda spec: tf.zeros(spec.shape), self._record_specs)
|
lambda spec: tf.zeros(spec.shape), self._record_specs)
|
||||||
restarter_state = ()
|
|
||||||
if self._restart_indicator is not None:
|
|
||||||
restarter_state = self._restart_indicator.initialize()
|
|
||||||
return TreeCumulativeSumQuery.GlobalState(
|
return TreeCumulativeSumQuery.GlobalState(
|
||||||
tree_state=initial_tree_state,
|
tree_state=initial_tree_state,
|
||||||
clip_value=tf.constant(self._clip_value, tf.float32),
|
clip_value=tf.constant(self._clip_value, tf.float32),
|
||||||
samples_cumulative_sum=initial_samples_cumulative_sum,
|
samples_cumulative_sum=initial_samples_cumulative_sum)
|
||||||
restarter_state=restarter_state)
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
|
@ -185,28 +172,41 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
global_state.tree_state)
|
global_state.tree_state)
|
||||||
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
||||||
cumulative_sum_noise)
|
cumulative_sum_noise)
|
||||||
restarter_state = global_state.restarter_state
|
|
||||||
if self._restart_indicator is not None:
|
|
||||||
restart_flag, restarter_state = self._restart_indicator.next(
|
|
||||||
restarter_state)
|
|
||||||
if restart_flag:
|
|
||||||
new_cumulative_sum = noised_cumulative_sum
|
|
||||||
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
|
|
||||||
new_global_state = attr.evolve(
|
new_global_state = attr.evolve(
|
||||||
global_state,
|
global_state,
|
||||||
samples_cumulative_sum=new_cumulative_sum,
|
samples_cumulative_sum=new_cumulative_sum,
|
||||||
tree_state=new_tree_state,
|
tree_state=new_tree_state)
|
||||||
restarter_state=restarter_state)
|
|
||||||
return noised_cumulative_sum, new_global_state
|
return noised_cumulative_sum, new_global_state
|
||||||
|
|
||||||
|
def reset_state(self, noised_results, global_state):
  """Returns the global state after restarting the tree.

  `restart_query.RestartQuery` calls this right after `get_noised_result`
  once its restarting condition is met.

  Args:
    noised_results: Noised cumulative sum returned by `get_noised_result`.
    global_state: Updated global state returned by `get_noised_result`, which
      holds the current sample's cumulative sum and the tree state for the
      next cumulative sum.

  Returns:
    New global state whose `samples_cumulative_sum` is the current noised
    cumulative sum and whose `tree_state` is the restarted tree state.
  """
  return attr.evolve(
      global_state,
      samples_cumulative_sum=noised_results,
      tree_state=self._tree_aggregator.reset_state(global_state.tree_state))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_l2_gaussian_query(cls,
|
def build_l2_gaussian_query(cls,
|
||||||
clip_norm,
|
clip_norm,
|
||||||
noise_multiplier,
|
noise_multiplier,
|
||||||
record_specs,
|
record_specs,
|
||||||
noise_seed=None,
|
noise_seed=None,
|
||||||
use_efficient=True,
|
use_efficient=True):
|
||||||
restart_indicator=None):
|
|
||||||
"""Returns a query instance with L2 norm clipping and Gaussian noise.
|
"""Returns a query instance with L2 norm clipping and Gaussian noise.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -221,8 +221,6 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
use_efficient: Boolean indicating the usage of the efficient tree
|
use_efficient: Boolean indicating the usage of the efficient tree
|
||||||
aggregation algorithm based on the paper "Efficient Use of
|
aggregation algorithm based on the paper "Efficient Use of
|
||||||
Differentially Private Binary Trees".
|
Differentially Private Binary Trees".
|
||||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
|
||||||
boolean indicator for resetting the tree state.
|
|
||||||
"""
|
"""
|
||||||
if clip_norm <= 0:
|
if clip_norm <= 0:
|
||||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
||||||
|
@ -245,8 +243,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
clip_value=clip_norm,
|
clip_value=clip_norm,
|
||||||
record_specs=record_specs,
|
record_specs=record_specs,
|
||||||
noise_generator=gaussian_noise_generator,
|
noise_generator=gaussian_noise_generator,
|
||||||
use_efficient=use_efficient,
|
use_efficient=use_efficient)
|
||||||
restart_indicator=restart_indicator)
|
|
||||||
|
|
||||||
|
|
||||||
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
@ -300,8 +297,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
node. Noise standard deviation is specified outside the `dp_query` by the
|
node. Noise standard deviation is specified outside the `dp_query` by the
|
||||||
user when defining `noise_fn` and should have order
|
user when defining `noise_fn` and should have order
|
||||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
|
||||||
boolean indicator for resetting the tree state.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@attr.s(frozen=True)
|
@attr.s(frozen=True)
|
||||||
|
@ -314,21 +309,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
clip_value: The clipping value to be passed to clip_fn.
|
clip_value: The clipping value to be passed to clip_fn.
|
||||||
previous_tree_noise: Cumulative noise by tree aggregation from the
|
previous_tree_noise: Cumulative noise by tree aggregation from the
|
||||||
previous time the query is called on a sample.
|
previous time the query is called on a sample.
|
||||||
restarter_state: Current state of the restarter to indicate whether
|
|
||||||
the tree state will be reset.
|
|
||||||
"""
|
"""
|
||||||
tree_state = attr.ib()
|
tree_state = attr.ib()
|
||||||
clip_value = attr.ib()
|
clip_value = attr.ib()
|
||||||
previous_tree_noise = attr.ib()
|
previous_tree_noise = attr.ib()
|
||||||
restarter_state = attr.ib()
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
record_specs,
|
record_specs,
|
||||||
noise_generator,
|
noise_generator,
|
||||||
clip_fn,
|
clip_fn,
|
||||||
clip_value,
|
clip_value,
|
||||||
use_efficient=True,
|
use_efficient=True):
|
||||||
restart_indicator=None):
|
|
||||||
"""Initializes the `TreeResidualSumQuery`.
|
"""Initializes the `TreeResidualSumQuery`.
|
||||||
|
|
||||||
Consider using `build_l2_gaussian_query` for the construction of a
|
Consider using `build_l2_gaussian_query` for the construction of a
|
||||||
|
@ -346,8 +337,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
use_efficient: Boolean indicating the usage of the efficient tree
|
use_efficient: Boolean indicating the usage of the efficient tree
|
||||||
aggregation algorithm based on the paper "Efficient Use of
|
aggregation algorithm based on the paper "Efficient Use of
|
||||||
Differentially Private Binary Trees".
|
Differentially Private Binary Trees".
|
||||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
|
||||||
boolean indicator for resetting the tree state.
|
|
||||||
"""
|
"""
|
||||||
self._clip_fn = clip_fn
|
self._clip_fn = clip_fn
|
||||||
self._clip_value = clip_value
|
self._clip_value = clip_value
|
||||||
|
@ -357,7 +346,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
noise_generator)
|
noise_generator)
|
||||||
else:
|
else:
|
||||||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||||
self._restart_indicator = restart_indicator
|
|
||||||
|
|
||||||
def _zero_initial_noise(self):
|
def _zero_initial_noise(self):
|
||||||
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
|
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
|
||||||
|
@ -366,14 +354,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
initial_tree_state = self._tree_aggregator.init_state()
|
initial_tree_state = self._tree_aggregator.init_state()
|
||||||
restarter_state = ()
|
|
||||||
if self._restart_indicator is not None:
|
|
||||||
restarter_state = self._restart_indicator.initialize()
|
|
||||||
return TreeResidualSumQuery.GlobalState(
|
return TreeResidualSumQuery.GlobalState(
|
||||||
tree_state=initial_tree_state,
|
tree_state=initial_tree_state,
|
||||||
clip_value=tf.constant(self._clip_value, tf.float32),
|
clip_value=tf.constant(self._clip_value, tf.float32),
|
||||||
previous_tree_noise=self._zero_initial_noise(),
|
previous_tree_noise=self._zero_initial_noise())
|
||||||
restarter_state=restarter_state)
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
|
@ -412,28 +396,39 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
|
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
|
||||||
sample_state, tree_noise,
|
sample_state, tree_noise,
|
||||||
global_state.previous_tree_noise)
|
global_state.previous_tree_noise)
|
||||||
restarter_state = global_state.restarter_state
|
|
||||||
if self._restart_indicator is not None:
|
|
||||||
restart_flag, restarter_state = self._restart_indicator.next(
|
|
||||||
restarter_state)
|
|
||||||
if restart_flag:
|
|
||||||
tree_noise = self._zero_initial_noise()
|
|
||||||
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
|
|
||||||
new_global_state = attr.evolve(
|
new_global_state = attr.evolve(
|
||||||
global_state,
|
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
||||||
previous_tree_noise=tree_noise,
|
|
||||||
tree_state=new_tree_state,
|
|
||||||
restarter_state=restarter_state)
|
|
||||||
return noised_sample, new_global_state
|
return noised_sample, new_global_state
|
||||||
|
|
||||||
|
def reset_state(self, noised_results, global_state):
|
||||||
|
"""Returns state after resetting the tree.
|
||||||
|
|
||||||
|
This function will be used in `restart_query.RestartQuery` after calling
|
||||||
|
`get_noised_result` when the restarting condition is met.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noised_results: Noised cumulative sum returned by `get_noised_result`.
|
||||||
|
global_state: Updated global state returned by `get_noised_result`, which
|
||||||
|
records noise for the conceptual cumulative sum of the current leaf
|
||||||
|
node, and tree state for the next conceptual cumulative sum.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New global state with zero noise and restarted tree state.
|
||||||
|
"""
|
||||||
|
del noised_results
|
||||||
|
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
||||||
|
return attr.evolve(
|
||||||
|
global_state,
|
||||||
|
previous_tree_noise=self._zero_initial_noise(),
|
||||||
|
tree_state=new_tree_state)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_l2_gaussian_query(cls,
|
def build_l2_gaussian_query(cls,
|
||||||
clip_norm,
|
clip_norm,
|
||||||
noise_multiplier,
|
noise_multiplier,
|
||||||
record_specs,
|
record_specs,
|
||||||
noise_seed=None,
|
noise_seed=None,
|
||||||
use_efficient=True,
|
use_efficient=True):
|
||||||
restart_indicator=None):
|
|
||||||
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
|
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -448,8 +443,6 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
use_efficient: Boolean indicating the usage of the efficient tree
|
use_efficient: Boolean indicating the usage of the efficient tree
|
||||||
aggregation algorithm based on the paper "Efficient Use of
|
aggregation algorithm based on the paper "Efficient Use of
|
||||||
Differentially Private Binary Trees".
|
Differentially Private Binary Trees".
|
||||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
|
||||||
boolean indicator for resetting the tree state.
|
|
||||||
"""
|
"""
|
||||||
if clip_norm <= 0:
|
if clip_norm <= 0:
|
||||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
||||||
|
@ -472,8 +465,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
clip_value=clip_norm,
|
clip_value=clip_norm,
|
||||||
record_specs=record_specs,
|
record_specs=record_specs,
|
||||||
noise_generator=gaussian_noise_generator,
|
noise_generator=gaussian_noise_generator,
|
||||||
use_efficient=use_efficient,
|
use_efficient=use_efficient)
|
||||||
restart_indicator=restart_indicator)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next
|
# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next
|
||||||
|
|
|
@ -303,15 +303,12 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
|
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
|
||||||
tree_node_value, frequency):
|
tree_node_value, frequency):
|
||||||
total_steps = 20
|
total_steps = 20
|
||||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
|
||||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||||
clip_fn=_get_l2_clip_fn(),
|
clip_fn=_get_l2_clip_fn(),
|
||||||
clip_value=scalar_value + 1., # no clip
|
clip_value=scalar_value + 1., # no clip
|
||||||
noise_generator=lambda: tree_node_value,
|
noise_generator=lambda: tree_node_value,
|
||||||
record_specs=tf.TensorSpec([]),
|
record_specs=tf.TensorSpec([]),
|
||||||
use_efficient=False,
|
use_efficient=False)
|
||||||
restart_indicator=indicator,
|
|
||||||
)
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
|
@ -319,6 +316,8 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
|
if i % frequency == frequency - 1:
|
||||||
|
global_state = query.reset_state(query_result, global_state)
|
||||||
# Expected value is the combination of cumsum of signal; sum of trees
|
# Expected value is the combination of cumsum of signal; sum of trees
|
||||||
# that have been reset; current tree sum. The tree aggregation value can
|
# that have been reset; current tree sum. The tree aggregation value can
|
||||||
# be inferred from the binary representation of the current step.
|
# be inferred from the binary representation of the current step.
|
||||||
|
@ -446,15 +445,12 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
|
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
|
||||||
frequency):
|
frequency):
|
||||||
total_steps = 20
|
total_steps = 20
|
||||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
|
||||||
query = tree_aggregation_query.TreeResidualSumQuery(
|
query = tree_aggregation_query.TreeResidualSumQuery(
|
||||||
clip_fn=_get_l2_clip_fn(),
|
clip_fn=_get_l2_clip_fn(),
|
||||||
clip_value=scalar_value + 1., # no clip
|
clip_value=scalar_value + 1., # no clip
|
||||||
noise_generator=lambda: tree_node_value,
|
noise_generator=lambda: tree_node_value,
|
||||||
record_specs=tf.TensorSpec([]),
|
record_specs=tf.TensorSpec([]),
|
||||||
use_efficient=False,
|
use_efficient=False)
|
||||||
restart_indicator=indicator,
|
|
||||||
)
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
|
@ -462,6 +458,8 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
|
if i % frequency == frequency - 1:
|
||||||
|
global_state = query.reset_state(query_result, global_state)
|
||||||
# Expected value is the signal of the current round plus the residual of
|
# Expected value is the signal of the current round plus the residual of
|
||||||
# two consecutive tree aggregation values. The tree aggregation value can
|
# two consecutive tree aggregation values. The tree aggregation value can
|
||||||
# be inferred from the binary representation of the current step.
|
# be inferred from the binary representation of the current step.
|
||||||
|
|
|
@ -396,26 +396,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
|
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
|
||||||
|
|
||||||
|
|
||||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
|
||||||
|
|
||||||
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
|
||||||
def test_round_raise(self, frequency):
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, 'Restart frequency should be equal or larger than 1'):
|
|
||||||
tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
|
||||||
def test_round_indicator(self, frequency):
|
|
||||||
total_steps = 20
|
|
||||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
|
||||||
state = indicator.initialize()
|
|
||||||
for i in range(total_steps):
|
|
||||||
flag, state = indicator.next(state)
|
|
||||||
if i % frequency == frequency - 1:
|
|
||||||
self.assertTrue(flag)
|
|
||||||
else:
|
|
||||||
self.assertFalse(flag)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue