Restart the tree state in tree-related DPQuery for streaming data: a general abstract class and an instance that restarts every few rounds.

PiperOrigin-RevId: 390244330
This commit is contained in:
Zheng Xu 2021-08-11 16:25:32 -07:00 committed by A. Unique TensorFlower
parent f44dcb8760
commit b4c04093cf
4 changed files with 399 additions and 67 deletions

View file

@ -16,7 +16,10 @@
`TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise
based on tree aggregation. When using an appropriate noise function (e.g.,
Gaussian noise), it allows for efficient differentially private algorithms under
continual observation, without prior subsampling or shuffling assumptions.
continual observation, without prior subsampling or shuffling assumptions. This
module implements the core logic of tree aggregation in Tensorflow, which serves
as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible.
"""
import abc
@ -26,6 +29,10 @@ import attr
import tensorflow as tf
# TODO(b/192464750): find a proper place for the helper functions, privatize
# the tree aggregation logic, and encourage users to use the DPQuery API.
class ValueGenerator(metaclass=abc.ABCMeta):
"""Base class establishing interface for stateful value generation.
@ -40,6 +47,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
Returns:
An initial state.
"""
raise NotImplementedError
@abc.abstractmethod
def next(self, state):
@ -52,6 +60,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
A pair (value, new_state) where value is the next value and new_state
is the advanced state.
"""
raise NotImplementedError
class GaussianNoiseGenerator(ValueGenerator):
@ -148,6 +157,78 @@ class StatelessValueGenerator(ValueGenerator):
return self.value_fn(), state
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
class RestartIndicator(metaclass=abc.ABCMeta):
  """Abstract interface for deciding when the tree state should be restarted.

  Implementations carry an explicit state. Every call to `next` consumes the
  current state and yields a bool restart decision together with the advanced
  state.
  """

  @abc.abstractmethod
  def initialize(self):
    """Creates the starting state of the indicator.

    Returns:
      An initial state.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def next(self, state):
    """Produces the next restart decision and advances the indicator state.

    Args:
      state: The current state.

    Returns:
      A pair (value, new_state) where value is the bool indicator and
      new_state is the advanced state.
    """
    raise NotImplementedError
class PeriodicRoundRestartIndicator(RestartIndicator):
  """Indicator for resetting the tree state after every `frequency` queries.

  The indicator maintains an internal counter of `next` calls as its state.
  """

  def __init__(self, frequency: int):
    """Constructs the `PeriodicRoundRestartIndicator`.

    Args:
      frequency: The `next` function will return `True` every `frequency` number
        of `next` calls.

    Raises:
      ValueError: If `frequency` is smaller than 1.
    """
    if frequency < 1:
      raise ValueError('Restart frequency should be equal or larger than 1, '
                       f'got {frequency}.')
    self.frequency = tf.constant(frequency, tf.int32)

  def initialize(self):
    """Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
    return tf.constant(0, tf.int32)

  def next(self, state):
    """Gets next bool indicator and advances the state.

    Args:
      state: The current state.

    Returns:
      A pair (value, new_state) where value is the bool indicator and new_state
      is `state + 1`.
    """
    # The counter is incremented before the check, so starting from the state
    # returned by `initialize`, the flag is true on every `frequency`-th call.
    state = state + tf.constant(1, tf.int32)
    flag = state % self.frequency == 0
    return flag, state
@attr.s(eq=False, frozen=True, slots=True)
class TreeState(object):
"""Class defining state of the tree.
@ -166,6 +247,7 @@ class TreeState(object):
value_generator_state = attr.ib(type=Any)
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
@tf.function
def get_step_idx(state: TreeState) -> tf.Tensor:
"""Returns the current leaf node index based on `TreeState.level_buffer_idx`."""
@ -188,6 +270,14 @@ class TreeAggregator():
https://dl.acm.org/doi/pdf/10.1145/1806689.1806787. A buffer at the scale of
tree depth is maintained and updated when a new conceptual leaf node arrives.
Example usage:
random_generator = GaussianNoiseGenerator(...)
tree_aggregator = TreeAggregator(random_generator)
state = tree_aggregator.init_state()
for leaf_node_idx in range(total_steps):
assert leaf_node_idx == get_step_idx(state)
noise, state = tree_aggregator.get_cumsum_and_update(state)
Attributes:
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
value for each tree node.
@ -205,14 +295,8 @@ class TreeAggregator():
else:
self.value_generator = StatelessValueGenerator(value_generator)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
"""
value_generator_state = self.value_generator.initialize()
def _get_init_state(self, value_generator_state) -> TreeState:
"""Returns initial `TreeState` given `value_generator_state`."""
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
0, dtype=tf.int32)).stack()
@ -224,12 +308,28 @@ class TreeAggregator():
new_val)
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
level_buffer_structure, new_val)
return TreeState(
level_buffer=level_buffer,
level_buffer_idx=level_buffer_idx,
value_generator_state=value_generator_state)
def init_state(self) -> TreeState:
  """Builds the `TreeState` for a brand-new tree.

  A fresh value generator state is created and handed to `_get_init_state`,
  which assembles the `TreeState` of a tree with a single leaf node (node
  index 0) whose value comes from the value generator.

  Returns:
    An initialized `TreeState`.
  """
  return self._get_init_state(self.value_generator.initialize())
def reset_state(self, state: TreeState) -> TreeState:
  """Restarts the tree while carrying over the value generator state.

  Args:
    state: The `TreeState` of the tree that is being discarded; only its
      `value_generator_state` is reused.

  Returns:
    The `TreeState` produced by `_get_init_state` for the new tree.
  """
  carried_generator_state = state.value_generator_state
  return self._get_init_state(carried_generator_state)
@tf.function
def _get_cumsum(self, level_buffer: Collection[tf.Tensor]) -> tf.Tensor:
return tf.nest.map_structure(lambda x: tf.reduce_sum(x, axis=0),
@ -238,7 +338,7 @@ class TreeAggregator():
@tf.function
def get_cumsum_and_update(self,
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
"""Returns tree aggregated value and updated `TreeState` for one step.
"""Returns tree aggregated noise and updates `TreeState` for the next step.
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
that `get_step_idx` can be called to get the current index of the leaf node
@ -249,10 +349,20 @@ class TreeAggregator():
Args:
state: `TreeState` for the current leaf node, index can be queried by
`tree_aggregation.get_step_idx(state)`.
Returns:
Tuple of (noise, state) where `noise` is generated by tree aggregated
protocol for the cumulative sum of streaming data, and `state` is the
updated `TreeState`.
"""
level_buffer_idx, level_buffer, value_generator_state = (
state.level_buffer_idx, state.level_buffer, state.value_generator_state)
# We only publicize a combined function for updating state and returning
# noised results because this DPQuery is designed for the streaming data,
# and we only maintain a dynamic memory buffer of max size logT. Only the
# most recent noised results can be queried, and the queries are
# expected to happen for every step in the streaming setting.
cumsum = self._get_cumsum(level_buffer)
new_level_buffer = tf.nest.map_structure(
@ -311,6 +421,14 @@ class EfficientTreeAggregator():
`sigma * sqrt(2^{d-1}/(2^d-1))`. which becomes `sigma / sqrt(2)` when
the tree is very tall.
Example usage:
random_generator = GaussianNoiseGenerator(...)
tree_aggregator = EfficientTreeAggregator(random_generator)
state = tree_aggregator.init_state()
for leaf_node_idx in range(total_steps):
assert leaf_node_idx == get_step_idx(state)
noise, state = tree_aggregator.get_cumsum_and_update(state)
Attributes:
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
value for each tree node.
@ -328,17 +446,8 @@ class EfficientTreeAggregator():
else:
self.value_generator = StatelessValueGenerator(value_generator)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
Returns:
An initialized `TreeState`.
"""
value_generator_state = self.value_generator.initialize()
def _get_init_state(self, value_generator_state):
"""Returns initial buffer for `TreeState`."""
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
0, dtype=tf.int32)).stack()
@ -350,12 +459,28 @@ class EfficientTreeAggregator():
new_val)
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
level_buffer_structure, new_val)
return TreeState(
level_buffer=level_buffer,
level_buffer_idx=level_buffer_idx,
value_generator_state=value_generator_state)
def init_state(self) -> TreeState:
  """Builds the `TreeState` for a brand-new tree.

  A fresh value generator state is created and handed to `_get_init_state`,
  which assembles the `TreeState` of a tree with a single leaf node (node
  index 0) whose value comes from the value generator.

  Returns:
    An initialized `TreeState`.
  """
  return self._get_init_state(self.value_generator.initialize())
def reset_state(self, state: TreeState) -> TreeState:
  """Restarts the tree while carrying over the value generator state.

  Args:
    state: The `TreeState` of the tree that is being discarded; only its
      `value_generator_state` is reused.

  Returns:
    The `TreeState` produced by `_get_init_state` for the new tree.
  """
  carried_generator_state = state.value_generator_state
  return self._get_init_state(carried_generator_state)
@tf.function
def _get_cumsum(self, state: TreeState) -> tf.Tensor:
"""Returns weighted cumulative sum of noise based on `TreeState`."""
@ -377,7 +502,7 @@ class EfficientTreeAggregator():
@tf.function
def get_cumsum_and_update(self,
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
"""Returns tree aggregated value and updated `TreeState` for one step.
"""Returns tree aggregated noise and updates `TreeState` for the next step.
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
that `get_step_idx` can be called to get the current index of the leaf node
@ -390,7 +515,17 @@ class EfficientTreeAggregator():
Args:
state: `TreeState` for the current leaf node, index can be queried by
`tree_aggregation.get_step_idx(state)`.
Returns:
Tuple of (noise, state) where `noise` is generated by tree aggregated
protocol for the cumulative sum of streaming data, and `state` is the
updated `TreeState`.
"""
# We only publicize a combined function for updating state and returning
# noised results because this DPQuery is designed for the streaming data,
# and we only maintain a dynamic memory buffer of max size logT. Only the
# most recent noised results can be queried, and the queries are
# expected to happen for every step in the streaming setting.
cumsum = self._get_cumsum(state)
level_buffer_idx, level_buffer, value_generator_state = (

View file

@ -32,13 +32,35 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for adding correlated noise through tree structure.
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
First clips and sums records in current sample, returns cumulative sum of
samples over time (instead of only current sample) with added noise for
cumulative sum proportional to log(T), T being the number of times the query
is called.
class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
"""Returns private cumulative sums by clipping and adding correlated noise.
Consider calling `get_noised_result` T times, and each (x_i, i=0,1,...,T-1) is
the private value returned by `accumulate_record`, i.e. x_i = sum_{j=0}^{n-1}
x_{i,j} where each x_{i,j} is a private record in the database. This class is
intended to make multiple queries, which release privatized values of the
cumulative sums s_i = sum_{k=0}^{i} x_k, for i=0,...,T-1.
Each call to `get_noised_result` releases the next cumulative sum s_i, which
is in contrast to the GaussianSumQuery that releases x_i. Noise for the
cumulative sums is accomplished using the tree aggregation logic in
`tree_aggregation`, which is proportional to log(T).
Example usage:
query = TreeCumulativeSumQuery(...)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i, samples in enumerate(streaming_samples):
sample_state = query.initial_sample_state(samples[0])
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
for j,sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_cumsum is privatized estimate of s_i
noised_cumsum, global_state = query.get_noised_result(
sample_state, global_state)
Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
@ -52,6 +74,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
node. Noise standard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
@attr.s(frozen=True)
@ -63,17 +87,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
each level state.
clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
"""
tree_state = attr.ib()
clip_value = attr.ib()
samples_cumulative_sum = attr.ib()
restarter_state = attr.ib()
def __init__(self,
record_specs,
noise_generator,
clip_fn,
clip_value,
use_efficient=True):
use_efficient=True,
restart_indicator=None):
"""Initializes the `TreeCumulativeSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a
@ -91,6 +119,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
self._clip_fn = clip_fn
self._clip_value = clip_value
@ -100,17 +130,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
noise_generator)
else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state()
initial_samples_cumulative_sum = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape), self._record_specs)
initial_state = TreeCumulativeSumQuery.GlobalState(
restarter_state = None
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeCumulativeSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
samples_cumulative_sum=initial_samples_cumulative_sum)
return initial_state
samples_cumulative_sum=initial_samples_cumulative_sum,
restarter_state=restarter_state)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -151,13 +185,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
tf.add, global_state.samples_cumulative_sum, sample_state)
cumulative_sum_noise, new_tree_state = self._tree_aggregator.get_cumsum_and_update(
global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
new_cumulative_sum = noised_cumulative_sum
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve(
global_state,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state)
noised_cum_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
return noised_cum_sum, new_global_state
tree_state=new_tree_state,
restarter_state=restarter_state)
return noised_cumulative_sum, new_global_state
@classmethod
def build_l2_gaussian_query(cls,
@ -165,7 +207,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
noise_multiplier,
record_specs,
noise_seed=None,
use_efficient=True):
use_efficient=True,
restart_indicator=None):
"""Returns a query instance with L2 norm clipping and Gaussian noise.
Args:
@ -180,6 +223,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -202,22 +247,48 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm,
record_specs=record_specs,
noise_generator=gaussian_noise_generator,
use_efficient=use_efficient)
use_efficient=use_efficient,
restart_indicator=restart_indicator)
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for adding correlated noise through tree structure.
"""Implements DPQuery for adding correlated noise through tree structure.
Clips and sums records in current sample; returns the current sample adding
the noise residual from tree aggregation. The returned value is conceptually
equivalent to the following: calculates cumulative sum of samples over time
(instead of only current sample) with added noise for cumulative sum
proportional to log(T), T being the number of times the query is called;
returns the residual between the current noised cumsum and the previous one
when the query is called. Combining this query with a SGD optimizer can be
used to implement the DP-FTRL algorithm in
Clips and sums records in current sample x_i = sum_{j=0}^{n-1} x_{i,j};
returns the current sample adding the noise residual from tree aggregation.
The returned value is conceptually equivalent to the following: calculates
cumulative sum of samples over time s_i = sum_{k=0}^i x_k (instead of only
current sample) with added noise by tree aggregation protocol that is
proportional to log(T), T being the number of times the query is called;
returns the residual between the current noised cumsum noised(s_i) and the
previous one noised(s_{i-1}) when the query is called.
This can be used as a drop-in replacement for `GaussianSumQuery`, and can
offer stronger utility/privacy tradeoffs when amplification-via-sampling is not
possible, or when privacy epsilon is relatively large. This may result in
more noise by a log(T) factor in each individual estimate of x_i, but if the
x_i are used in the underlying code to compute cumulative sums, the noise in
those sums can be less. That is, this allows us to adapt code that was written
to use a regular `SumQuery` to benefit from the tree aggregation protocol.
Combining this query with a SGD optimizer can be used to implement the
DP-FTRL algorithm in
"Practical and Private (Deep) Learning without Sampling or Shuffling".
Example usage:
query = TreeResidualSumQuery(...)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i, samples in enumerate(streaming_samples):
sample_state = query.initial_sample_state(samples[0])
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
for j,sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_sum is privatized estimate of x_i by conceptually postprocessing
# noised cumulative sum s_i
noised_sum, global_state = query.get_noised_result(
sample_state, global_state)
Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
arguments: a flat list of vars in a record and a `clip_value` to clip the
@ -231,6 +302,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
node. Noise standard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
@attr.s(frozen=True)
@ -243,21 +316,25 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value: The clipping value to be passed to clip_fn.
previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
"""
tree_state = attr.ib()
clip_value = attr.ib()
previous_tree_noise = attr.ib()
restarter_state = attr.ib()
def __init__(self,
record_specs,
noise_generator,
clip_fn,
clip_value,
use_efficient=True):
"""Initializes the `TreeResidualSumQuery`.
use_efficient=True,
restart_indicator=None):
"""Initializes the `TreeResidualSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a
`TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
`TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
Args:
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
@ -271,6 +348,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
self._clip_fn = clip_fn
self._clip_value = clip_value
@ -280,16 +359,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noise_generator)
else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def _zero_initial_noise(self):
  """Returns a structure of zero tensors, one per spec in `_record_specs`."""

  def _zeros_for(spec):
    # Each zero tensor takes the shape declared by its spec.
    return tf.zeros(spec.shape)

  return tf.nest.map_structure(_zeros_for, self._record_specs)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state()
initial_noise = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
self._record_specs)
restarter_state = None
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeResidualSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
previous_tree_noise=initial_noise)
previous_tree_noise=self._zero_initial_noise(),
restarter_state=restarter_state)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -328,8 +414,18 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise,
global_state.previous_tree_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
tree_noise = self._zero_initial_noise()
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
global_state,
previous_tree_noise=tree_noise,
tree_state=new_tree_state,
restarter_state=restarter_state)
return noised_sample, new_global_state
@classmethod
@ -338,7 +434,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noise_multiplier,
record_specs,
noise_seed=None,
use_efficient=True):
use_efficient=True,
restart_indicator=None):
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
Args:
@ -353,6 +450,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
"""
if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -375,7 +474,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm,
record_specs=record_specs,
noise_generator=gaussian_noise_generator,
use_efficient=use_efficient)
use_efficient=use_efficient,
restart_indicator=restart_indicator)
@tf.function

View file

@ -263,13 +263,13 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(query_result, expected_sum)
@parameterized.named_parameters(
('s0t1step8', 0., 1., [1., 1., 2., 1., 2., 2., 3., 1.]),
('s1t1step8', 1., 1., [2., 3., 5., 5., 7., 8., 10., 9.]),
('s1t2step8', 1., 2., [3., 4., 7., 6., 9., 10., 13., 10.]),
('s0t1', 0., 1.),
('s1t1', 1., 1.),
('s1t2', 1., 2.),
)
def test_partial_sum_scalar_tree_aggregation(self, scalar_value,
tree_node_value,
expected_values):
tree_node_value):
total_steps = 8
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
@ -279,14 +279,54 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for val in expected_values:
# For each streaming step i , the expected value is roughly
# `scalar_value*i + tree_aggregation(tree_node_value, i)`
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
self.assertEqual(query_result, val)
# For each streaming step i , the expected value is roughly
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
# The tree aggregation value can be inferred from the binary
# representation of the current step.
self.assertEqual(
query_result,
scalar_value * (i + 1) + tree_node_value * bin(i + 1)[2:].count('1'))
@parameterized.named_parameters(
    ('s0t1f1', 0., 1., 1),
    ('s0t1f2', 0., 1., 2),
    ('s0t1f5', 0., 1., 5),
    ('s1t1f5', 1., 1., 5),
    ('s1t2f2', 1., 2., 2),
    ('s1t5f6', 1., 5., 6),
)
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
                                           tree_node_value, frequency):
  """Checks cumulative sums when the tree restarts every `frequency` steps."""
  num_rounds = 20
  restarter = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
  query = tree_aggregation_query.TreeCumulativeSumQuery(
      clip_fn=_get_l2_clip_fn(),
      clip_value=scalar_value + 1.,  # no clip
      noise_generator=lambda: tree_node_value,
      record_specs=tf.TensorSpec([]),
      use_efficient=False,
      restart_indicator=restarter,
  )
  global_state = query.initial_global_state()
  params = query.derive_sample_params(global_state)

  def _ones_in_binary(number):
    # With a constant node value, the tree aggregation value at a step is the
    # node value times the number of set bits of that step.
    return bin(number)[2:].count('1')

  for step in range(num_rounds):
    sample_state = query.initial_sample_state(scalar_value)
    sample_state = query.accumulate_record(params, sample_state, scalar_value)
    query_result, global_state = query.get_noised_result(
        sample_state, global_state)
    # Three contributions: the cumulative signal, the noise of every tree
    # that has already been reset, and the noise of the current tree.
    signal = scalar_value * (step + 1)
    finished_trees = (
        step // frequency * tree_node_value * _ones_in_binary(frequency))
    current_tree = tree_node_value * _ones_in_binary(step % frequency + 1)
    expected = signal + finished_trees + current_tree
    self.assertEqual(query_result, expected)
@parameterized.named_parameters(
('efficient', True, tree_aggregation.EfficientTreeAggregator),
@ -395,6 +435,42 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
)
self.assertIsInstance(query._tree_aggregator, tree_class)
@parameterized.named_parameters(
    ('s0t1f1', 0., 1., 1),
    ('s0t1f2', 0., 1., 2),
    ('s0t1f5', 0., 1., 5),
    ('s1t1f5', 1., 1., 5),
    ('s1t2f2', 1., 2., 2),
    ('s1t5f6', 1., 5., 6),
)
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
                                       frequency):
  """Checks `TreeResidualSumQuery` results with a periodic tree restart."""
  total_steps = 20
  indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
  query = tree_aggregation_query.TreeResidualSumQuery(
      clip_fn=_get_l2_clip_fn(),
      clip_value=scalar_value + 1.,  # no clip
      noise_generator=lambda: tree_node_value,
      record_specs=tf.TensorSpec([]),
      use_efficient=False,
      restart_indicator=indicator,
  )
  global_state = query.initial_global_state()
  params = query.derive_sample_params(global_state)
  for i in range(total_steps):
    sample_state = query.initial_sample_state(scalar_value)
    sample_state = query.accumulate_record(params, sample_state, scalar_value)
    query_result, global_state = query.get_noised_result(
        sample_state, global_state)
    # Expected value is the signal of the current round plus the residual of
    # two continuous tree aggregation values. The tree aggregation value can
    # be inferred from the binary representation of the current step.
    expected = scalar_value + tree_node_value * (
        bin(i % frequency + 1)[2:].count('1') -
        bin(i % frequency)[2:].count('1'))
    self.assertEqual(query_result, expected)
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):

View file

@ -365,5 +365,26 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
self.assertAllEqual(gstate, gstate2)
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `tree_aggregation.PeriodicRoundRestartIndicator`."""

  @parameterized.named_parameters(('zero', 0), ('negative', -1))
  def test_round_raise(self, frequency):
    """Frequencies below 1 are rejected at construction time."""
    expected_message = 'Restart frequency should be equal or larger than 1'
    with self.assertRaisesRegex(ValueError, expected_message):
      tree_aggregation.PeriodicRoundRestartIndicator(frequency)

  @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
  def test_round_indicator(self, frequency):
    """The flag is set exactly on every `frequency`-th call to `next`."""
    total_steps = 20
    indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
    state = indicator.initialize()
    for i in range(total_steps):
      flag, state = indicator.next(state)
      # Calls are counted from zero here, so restarts fall on the steps whose
      # index is one less than a multiple of `frequency`.
      is_restart_round = i % frequency == frequency - 1
      check = self.assertTrue if is_restart_round else self.assertFalse
      check(flag)
if __name__ == '__main__':
tf.test.main()