Restart the tree state in tree related DPQuery for streaming data: a general abstract class and an instance of restarting every a few rounds.
PiperOrigin-RevId: 390244330
This commit is contained in:
parent
f44dcb8760
commit
b4c04093cf
4 changed files with 399 additions and 67 deletions
|
@ -16,7 +16,10 @@
|
|||
`TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise
|
||||
based on tree aggregation. When using an appropriate noise function (e.g.,
|
||||
Gaussian noise), it allows for efficient differentially private algorithms under
|
||||
continual observation, without prior subsampling or shuffling assumptions.
|
||||
continual observation, without prior subsampling or shuffling assumptions. This
|
||||
module implements the core logic of tree aggregation in Tensorflow, which serves
|
||||
as helper functions for `tree_aggregation_query`. This module and helper
|
||||
functions are publicly accessible.
|
||||
"""
|
||||
|
||||
import abc
|
||||
|
@ -26,6 +29,10 @@ import attr
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
# TODO(b/192464750): find a proper place for the helper functions, privatize
|
||||
# the tree aggregation logic, and encourage users to use the DPQuery API.
|
||||
|
||||
|
||||
class ValueGenerator(metaclass=abc.ABCMeta):
|
||||
"""Base class establishing interface for stateful value generation.
|
||||
|
||||
|
@ -40,6 +47,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
|
|||
Returns:
|
||||
An initial state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def next(self, state):
|
||||
|
@ -52,6 +60,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
|
|||
A pair (value, new_state) where value is the next value and new_state
|
||||
is the advanced state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GaussianNoiseGenerator(ValueGenerator):
|
||||
|
@ -148,6 +157,78 @@ class StatelessValueGenerator(ValueGenerator):
|
|||
return self.value_fn(), state
|
||||
|
||||
|
||||
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
|
||||
# in the same module.
|
||||
|
||||
|
||||
class RestartIndicator(metaclass=abc.ABCMeta):
|
||||
"""Base class establishing interface for restarting the tree state.
|
||||
|
||||
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
|
||||
value is generated to indicate whether to restart, and the indicator state is
|
||||
advanced.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def initialize(self):
|
||||
"""Makes an initialized state for `RestartIndicator`.
|
||||
|
||||
Returns:
|
||||
An initial state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def next(self, state):
|
||||
"""Gets next bool indicator and advances the `RestartIndicator` state.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is bool indicator and new_state
|
||||
is the advanced state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PeriodicRoundRestartIndicator(RestartIndicator):
|
||||
"""Indicator for resetting the tree state after every a few number of queries.
|
||||
|
||||
The indicator will maintain an internal counter as state.
|
||||
"""
|
||||
|
||||
def __init__(self, frequency: int):
|
||||
"""Construct the `PeriodicRoundRestartIndicator`.
|
||||
|
||||
Args:
|
||||
frequency: The `next` function will return `True` every `frequency` number
|
||||
of `next` calls.
|
||||
"""
|
||||
if frequency < 1:
|
||||
raise ValueError('Restart frequency should be equal or larger than 1 '
|
||||
f'got {frequency}')
|
||||
self.frequency = tf.constant(frequency, tf.int32)
|
||||
|
||||
def initialize(self):
|
||||
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
|
||||
return tf.constant(0, tf.int32)
|
||||
|
||||
def next(self, state):
|
||||
"""Gets next bool indicator and advances the state.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is the bool indicator and new_state
|
||||
of `state+1`.
|
||||
"""
|
||||
state = state + tf.constant(1, tf.int32)
|
||||
flag = state % self.frequency == 0
|
||||
return flag, state
|
||||
|
||||
|
||||
@attr.s(eq=False, frozen=True, slots=True)
|
||||
class TreeState(object):
|
||||
"""Class defining state of the tree.
|
||||
|
@ -166,6 +247,7 @@ class TreeState(object):
|
|||
value_generator_state = attr.ib(type=Any)
|
||||
|
||||
|
||||
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
|
||||
@tf.function
|
||||
def get_step_idx(state: TreeState) -> tf.Tensor:
|
||||
"""Returns the current leaf node index based on `TreeState.level_buffer_idx`."""
|
||||
|
@ -188,6 +270,14 @@ class TreeAggregator():
|
|||
https://dl.acm.org/doi/pdf/10.1145/1806689.1806787. A buffer at the scale of
|
||||
tree depth is maintained and updated when a new conceptual leaf node arrives.
|
||||
|
||||
Example usage:
|
||||
random_generator = GaussianNoiseGenerator(...)
|
||||
tree_aggregator = TreeAggregator(random_generator)
|
||||
state = tree_aggregator.init_state()
|
||||
for leaf_node_idx in range(total_steps):
|
||||
assert leaf_node_idx == get_step_idx(state))
|
||||
noise, state = tree_aggregator.get_cumsum_and_update(state)
|
||||
|
||||
Attributes:
|
||||
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
|
||||
value for each tree node.
|
||||
|
@ -205,14 +295,8 @@ class TreeAggregator():
|
|||
else:
|
||||
self.value_generator = StatelessValueGenerator(value_generator)
|
||||
|
||||
def init_state(self) -> TreeState:
|
||||
"""Returns initial `TreeState`.
|
||||
|
||||
Initializes `TreeState` for a tree of a single leaf node: the respective
|
||||
initial node value in `TreeState.level_buffer` is generated by the value
|
||||
generator function, and the node index is 0.
|
||||
"""
|
||||
value_generator_state = self.value_generator.initialize()
|
||||
def _get_init_state(self, value_generator_state) -> TreeState:
|
||||
"""Returns initial `TreeState` given `value_generator_state`."""
|
||||
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
|
||||
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
|
||||
0, dtype=tf.int32)).stack()
|
||||
|
@ -224,12 +308,28 @@ class TreeAggregator():
|
|||
new_val)
|
||||
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
|
||||
level_buffer_structure, new_val)
|
||||
|
||||
return TreeState(
|
||||
level_buffer=level_buffer,
|
||||
level_buffer_idx=level_buffer_idx,
|
||||
value_generator_state=value_generator_state)
|
||||
|
||||
def init_state(self) -> TreeState:
|
||||
"""Returns initial `TreeState`.
|
||||
|
||||
Initializes `TreeState` for a tree of a single leaf node: the respective
|
||||
initial node value in `TreeState.level_buffer` is generated by the value
|
||||
generator function, and the node index is 0.
|
||||
|
||||
Returns:
|
||||
An initialized `TreeState`.
|
||||
"""
|
||||
value_generator_state = self.value_generator.initialize()
|
||||
return self._get_init_state(value_generator_state)
|
||||
|
||||
def reset_state(self, state: TreeState) -> TreeState:
|
||||
"""Returns reset `TreeState` after restarting a new tree."""
|
||||
return self._get_init_state(state.value_generator_state)
|
||||
|
||||
@tf.function
|
||||
def _get_cumsum(self, level_buffer: Collection[tf.Tensor]) -> tf.Tensor:
|
||||
return tf.nest.map_structure(lambda x: tf.reduce_sum(x, axis=0),
|
||||
|
@ -238,7 +338,7 @@ class TreeAggregator():
|
|||
@tf.function
|
||||
def get_cumsum_and_update(self,
|
||||
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
|
||||
"""Returns tree aggregated value and updated `TreeState` for one step.
|
||||
"""Returns tree aggregated noise and updates `TreeState` for the next step.
|
||||
|
||||
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
|
||||
that `get_step_idx` can be called to get the current index of the leaf node
|
||||
|
@ -249,10 +349,20 @@ class TreeAggregator():
|
|||
Args:
|
||||
state: `TreeState` for the current leaf node, index can be queried by
|
||||
`tree_aggregation.get_step_idx(state.level_buffer_idx)`.
|
||||
|
||||
Returns:
|
||||
Tuple of (noise, state) where `noise` is generated by tree aggregated
|
||||
protocol for the cumulative sum of streaming data, and `state` is the
|
||||
updated `TreeState`.
|
||||
"""
|
||||
|
||||
level_buffer_idx, level_buffer, value_generator_state = (
|
||||
state.level_buffer_idx, state.level_buffer, state.value_generator_state)
|
||||
# We only publicize a combined function for updating state and returning
|
||||
# noised results because this DPQuery is designed for the streaming data,
|
||||
# and we only maintain a dynamic memory buffer of max size logT. Only the
|
||||
# the most recent noised results can be queried, and the queries are
|
||||
# expected to happen for every step in the streaming setting.
|
||||
cumsum = self._get_cumsum(level_buffer)
|
||||
|
||||
new_level_buffer = tf.nest.map_structure(
|
||||
|
@ -311,6 +421,14 @@ class EfficientTreeAggregator():
|
|||
`sigma * sqrt(2^{d-1}/(2^d-1))`. which becomes `sigma / sqrt(2)` when
|
||||
the tree is very tall.
|
||||
|
||||
Example usage:
|
||||
random_generator = GaussianNoiseGenerator(...)
|
||||
tree_aggregator = EfficientTreeAggregator(random_generator)
|
||||
state = tree_aggregator.init_state()
|
||||
for leaf_node_idx in range(total_steps):
|
||||
assert leaf_node_idx == get_step_idx(state))
|
||||
noise, state = tree_aggregator.get_cumsum_and_update(state)
|
||||
|
||||
Attributes:
|
||||
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
|
||||
value for each tree node.
|
||||
|
@ -328,17 +446,8 @@ class EfficientTreeAggregator():
|
|||
else:
|
||||
self.value_generator = StatelessValueGenerator(value_generator)
|
||||
|
||||
def init_state(self) -> TreeState:
|
||||
"""Returns initial `TreeState`.
|
||||
|
||||
Initializes `TreeState` for a tree of a single leaf node: the respective
|
||||
initial node value in `TreeState.level_buffer` is generated by the value
|
||||
generator function, and the node index is 0.
|
||||
|
||||
Returns:
|
||||
An initialized `TreeState`.
|
||||
"""
|
||||
value_generator_state = self.value_generator.initialize()
|
||||
def _get_init_state(self, value_generator_state):
|
||||
"""Returns initial buffer for `TreeState`."""
|
||||
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
|
||||
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
|
||||
0, dtype=tf.int32)).stack()
|
||||
|
@ -350,12 +459,28 @@ class EfficientTreeAggregator():
|
|||
new_val)
|
||||
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
|
||||
level_buffer_structure, new_val)
|
||||
|
||||
return TreeState(
|
||||
level_buffer=level_buffer,
|
||||
level_buffer_idx=level_buffer_idx,
|
||||
value_generator_state=value_generator_state)
|
||||
|
||||
def init_state(self) -> TreeState:
|
||||
"""Returns initial `TreeState`.
|
||||
|
||||
Initializes `TreeState` for a tree of a single leaf node: the respective
|
||||
initial node value in `TreeState.level_buffer` is generated by the value
|
||||
generator function, and the node index is 0.
|
||||
|
||||
Returns:
|
||||
An initialized `TreeState`.
|
||||
"""
|
||||
value_generator_state = self.value_generator.initialize()
|
||||
return self._get_init_state(value_generator_state)
|
||||
|
||||
def reset_state(self, state: TreeState) -> TreeState:
|
||||
"""Returns reset `TreeState` after restarting a new tree."""
|
||||
return self._get_init_state(state.value_generator_state)
|
||||
|
||||
@tf.function
|
||||
def _get_cumsum(self, state: TreeState) -> tf.Tensor:
|
||||
"""Returns weighted cumulative sum of noise based on `TreeState`."""
|
||||
|
@ -377,7 +502,7 @@ class EfficientTreeAggregator():
|
|||
@tf.function
|
||||
def get_cumsum_and_update(self,
|
||||
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
|
||||
"""Returns tree aggregated value and updated `TreeState` for one step.
|
||||
"""Returns tree aggregated noise and updates `TreeState` for the next step.
|
||||
|
||||
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
|
||||
that `get_step_idx` can be called to get the current index of the leaf node
|
||||
|
@ -390,7 +515,17 @@ class EfficientTreeAggregator():
|
|||
Args:
|
||||
state: `TreeState` for the current leaf node, index can be queried by
|
||||
`tree_aggregation.get_step_idx(state.level_buffer_idx)`.
|
||||
|
||||
Returns:
|
||||
Tuple of (noise, state) where `noise` is generated by tree aggregated
|
||||
protocol for the cumulative sum of streaming data, and `state` is the
|
||||
updated `TreeState`..
|
||||
"""
|
||||
# We only publicize a combined function for updating state and returning
|
||||
# noised results because this DPQuery is designed for the streaming data,
|
||||
# and we only maintain a dynamic memory buffer of max size logT. Only the
|
||||
# the most recent noised results can be queried, and the queries are
|
||||
# expected to happen for every step in the streaming setting.
|
||||
cumsum = self._get_cumsum(state)
|
||||
|
||||
level_buffer_idx, level_buffer, value_generator_state = (
|
||||
|
|
|
@ -32,13 +32,35 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
|
|||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
|
||||
|
||||
class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements dp_query for adding correlated noise through tree structure.
|
||||
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
|
||||
# in the same module.
|
||||
|
||||
First clips and sums records in current sample, returns cumulative sum of
|
||||
samples over time (instead of only current sample) with added noise for
|
||||
cumulative sum proportional to log(T), T being the number of times the query
|
||||
is called.
|
||||
|
||||
class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Returns private cumulative sums by clipping and adding correlated noise.
|
||||
|
||||
Consider calling `get_noised_result` T times, and each (x_i, i=0,2,...,T-1) is
|
||||
the private value returned by `accumulate_record`, i.e. x_i = sum_{j=0}^{n-1}
|
||||
x_{i,j} where each x_{i,j} is a private record in the database. This class is
|
||||
intended to make multiple queries, which release privatized values of the
|
||||
cumulative sums s_i = sum_{k=0}^{i} x_k, for i=0,...,T-1.
|
||||
Each call to `get_noised_result` releases the next cumulative sum s_i, which
|
||||
is in contrast to the GaussianSumQuery that releases x_i. Noise for the
|
||||
cumulative sums is accomplished using the tree aggregation logic in
|
||||
`tree_aggregation`, which is proportional to log(T).
|
||||
|
||||
Example usage:
|
||||
query = TreeCumulativeSumQuery(...)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i, samples in enumerate(streaming_samples):
|
||||
sample_state = query.initial_sample_state(samples[0])
|
||||
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
|
||||
for j,sample in enumerate(samples):
|
||||
sample_state = query.accumulate_record(params, sample_state, sample)
|
||||
# noised_cumsum is privatized estimate of s_i
|
||||
noised_cumsum, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
|
||||
Attributes:
|
||||
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
|
||||
|
@ -52,6 +74,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
node. Noise stdandard deviation is specified outside the `dp_query` by the
|
||||
user when defining `noise_fn` and should have order
|
||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
|
@ -63,17 +87,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
each level state.
|
||||
clip_value: The clipping value to be passed to clip_fn.
|
||||
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
|
||||
restarter_state: Current state of the restarter to indicate whether
|
||||
the tree state will be reset.
|
||||
"""
|
||||
tree_state = attr.ib()
|
||||
clip_value = attr.ib()
|
||||
samples_cumulative_sum = attr.ib()
|
||||
restarter_state = attr.ib()
|
||||
|
||||
def __init__(self,
|
||||
record_specs,
|
||||
noise_generator,
|
||||
clip_fn,
|
||||
clip_value,
|
||||
use_efficient=True):
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
"""Initializes the `TreeCumulativeSumQuery`.
|
||||
|
||||
Consider using `build_l2_gaussian_query` for the construction of a
|
||||
|
@ -91,6 +119,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
self._clip_fn = clip_fn
|
||||
self._clip_value = clip_value
|
||||
|
@ -100,17 +130,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noise_generator)
|
||||
else:
|
||||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||
self._restart_indicator = restart_indicator
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
initial_tree_state = self._tree_aggregator.init_state()
|
||||
initial_samples_cumulative_sum = tf.nest.map_structure(
|
||||
lambda spec: tf.zeros(spec.shape), self._record_specs)
|
||||
initial_state = TreeCumulativeSumQuery.GlobalState(
|
||||
restarter_state = None
|
||||
if self._restart_indicator is not None:
|
||||
restarter_state = self._restart_indicator.initialize()
|
||||
return TreeCumulativeSumQuery.GlobalState(
|
||||
tree_state=initial_tree_state,
|
||||
clip_value=tf.constant(self._clip_value, tf.float32),
|
||||
samples_cumulative_sum=initial_samples_cumulative_sum)
|
||||
return initial_state
|
||||
samples_cumulative_sum=initial_samples_cumulative_sum,
|
||||
restarter_state=restarter_state)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
|
@ -151,13 +185,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
tf.add, global_state.samples_cumulative_sum, sample_state)
|
||||
cumulative_sum_noise, new_tree_state = self._tree_aggregator.get_cumsum_and_update(
|
||||
global_state.tree_state)
|
||||
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
||||
cumulative_sum_noise)
|
||||
restarter_state = global_state.restarter_state
|
||||
if self._restart_indicator is not None:
|
||||
restart_flag, restarter_state = self._restart_indicator.next(
|
||||
restarter_state)
|
||||
if restart_flag:
|
||||
new_cumulative_sum = noised_cumulative_sum
|
||||
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
|
||||
new_global_state = attr.evolve(
|
||||
global_state,
|
||||
samples_cumulative_sum=new_cumulative_sum,
|
||||
tree_state=new_tree_state)
|
||||
noised_cum_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
||||
cumulative_sum_noise)
|
||||
return noised_cum_sum, new_global_state
|
||||
tree_state=new_tree_state,
|
||||
restarter_state=restarter_state)
|
||||
return noised_cumulative_sum, new_global_state
|
||||
|
||||
@classmethod
|
||||
def build_l2_gaussian_query(cls,
|
||||
|
@ -165,7 +207,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noise_multiplier,
|
||||
record_specs,
|
||||
noise_seed=None,
|
||||
use_efficient=True):
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
"""Returns a query instance with L2 norm clipping and Gaussian noise.
|
||||
|
||||
Args:
|
||||
|
@ -180,6 +223,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
if clip_norm <= 0:
|
||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
||||
|
@ -202,22 +247,48 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
clip_value=clip_norm,
|
||||
record_specs=record_specs,
|
||||
noise_generator=gaussian_noise_generator,
|
||||
use_efficient=use_efficient)
|
||||
use_efficient=use_efficient,
|
||||
restart_indicator=restart_indicator)
|
||||
|
||||
|
||||
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements dp_query for adding correlated noise through tree structure.
|
||||
"""Implements DPQuery for adding correlated noise through tree structure.
|
||||
|
||||
Clips and sums records in current sample; returns the current sample adding
|
||||
the noise residual from tree aggregation. The returned value is conceptually
|
||||
equivalent to the following: calculates cumulative sum of samples over time
|
||||
(instead of only current sample) with added noise for cumulative sum
|
||||
proportional to log(T), T being the number of times the query is called;
|
||||
returns the residual between the current noised cumsum and the previous one
|
||||
when the query is called. Combining this query with a SGD optimizer can be
|
||||
used to implement the DP-FTRL algorithm in
|
||||
Clips and sums records in current sample x_i = sum_{j=0}^{n-1} x_{i,j};
|
||||
returns the current sample adding the noise residual from tree aggregation.
|
||||
The returned value is conceptually equivalent to the following: calculates
|
||||
cumulative sum of samples over time s_i = sum_{k=0}^i x_i (instead of only
|
||||
current sample) with added noise by tree aggregation protocol that is
|
||||
proportional to log(T), T being the number of times the query is called; r
|
||||
eturns the residual between the current noised cumsum noised(s_i) and the
|
||||
previous one noised(s_{i-1}) when the query is called.
|
||||
|
||||
This can be used as a drop-in replacement for `GaussianSumQuery`, and can
|
||||
offer stronger utility/privacy tradeoffs when aplification-via-sampling is not
|
||||
possible, or when privacy epsilon is relativly large. This may result in
|
||||
more noise by a log(T) factor in each individual estimate of x_i, but if the
|
||||
x_i are used in the underlying code to compute cumulative sums, the noise in
|
||||
those sums can be less. That is, this allows us to adapt code that was written
|
||||
to use a regular `SumQuery` to benefit from the tree aggregation protocol.
|
||||
|
||||
Combining this query with a SGD optimizer can be used to implement the
|
||||
DP-FTRL algorithm in
|
||||
"Practical and Private (Deep) Learning without Sampling or Shuffling".
|
||||
|
||||
Example usage:
|
||||
query = TreeResidualSumQuery(...)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i, samples in enumerate(streaming_samples):
|
||||
sample_state = query.initial_sample_state(samples[0])
|
||||
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
|
||||
for j,sample in enumerate(samples):
|
||||
sample_state = query.accumulate_record(params, sample_state, sample)
|
||||
# noised_sum is privatized estimate of x_i by conceptually postprocessing
|
||||
# noised cumulative sum s_i
|
||||
noised_sum, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
|
||||
Attributes:
|
||||
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
|
||||
arguments: a flat list of vars in a record and a `clip_value` to clip the
|
||||
|
@ -231,6 +302,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
node. Noise stdandard deviation is specified outside the `dp_query` by the
|
||||
user when defining `noise_fn` and should have order
|
||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
|
@ -243,21 +316,25 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
clip_value: The clipping value to be passed to clip_fn.
|
||||
previous_tree_noise: Cumulative noise by tree aggregation from the
|
||||
previous time the query is called on a sample.
|
||||
restarter_state: Current state of the restarter to indicate whether
|
||||
the tree state will be reset.
|
||||
"""
|
||||
tree_state = attr.ib()
|
||||
clip_value = attr.ib()
|
||||
previous_tree_noise = attr.ib()
|
||||
restarter_state = attr.ib()
|
||||
|
||||
def __init__(self,
|
||||
record_specs,
|
||||
noise_generator,
|
||||
clip_fn,
|
||||
clip_value,
|
||||
use_efficient=True):
|
||||
"""Initializes the `TreeResidualSumQuery`.
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
"""Initializes the `TreeCumulativeSumQuery`.
|
||||
|
||||
Consider using `build_l2_gaussian_query` for the construction of a
|
||||
`TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
|
||||
`TreeCumulativeSumQuery` with L2 norm clipping and Gaussian noise.
|
||||
|
||||
Args:
|
||||
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
|
||||
|
@ -271,6 +348,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
self._clip_fn = clip_fn
|
||||
self._clip_value = clip_value
|
||||
|
@ -280,16 +359,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noise_generator)
|
||||
else:
|
||||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||
self._restart_indicator = restart_indicator
|
||||
|
||||
def _zero_initial_noise(self):
|
||||
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
|
||||
self._record_specs)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
initial_tree_state = self._tree_aggregator.init_state()
|
||||
initial_noise = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
|
||||
self._record_specs)
|
||||
restarter_state = None
|
||||
if self._restart_indicator is not None:
|
||||
restarter_state = self._restart_indicator.initialize()
|
||||
return TreeResidualSumQuery.GlobalState(
|
||||
tree_state=initial_tree_state,
|
||||
clip_value=tf.constant(self._clip_value, tf.float32),
|
||||
previous_tree_noise=initial_noise)
|
||||
previous_tree_noise=self._zero_initial_noise(),
|
||||
restarter_state=restarter_state)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
|
@ -328,8 +414,18 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
|
||||
sample_state, tree_noise,
|
||||
global_state.previous_tree_noise)
|
||||
restarter_state = global_state.restarter_state
|
||||
if self._restart_indicator is not None:
|
||||
restart_flag, restarter_state = self._restart_indicator.next(
|
||||
restarter_state)
|
||||
if restart_flag:
|
||||
tree_noise = self._zero_initial_noise()
|
||||
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
|
||||
new_global_state = attr.evolve(
|
||||
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
||||
global_state,
|
||||
previous_tree_noise=tree_noise,
|
||||
tree_state=new_tree_state,
|
||||
restarter_state=restarter_state)
|
||||
return noised_sample, new_global_state
|
||||
|
||||
@classmethod
|
||||
|
@ -338,7 +434,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noise_multiplier,
|
||||
record_specs,
|
||||
noise_seed=None,
|
||||
use_efficient=True):
|
||||
use_efficient=True,
|
||||
restart_indicator=None):
|
||||
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
|
||||
|
||||
Args:
|
||||
|
@ -353,6 +450,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
|
||||
boolean indicator for resetting the tree state.
|
||||
"""
|
||||
if clip_norm <= 0:
|
||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
||||
|
@ -375,7 +474,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
clip_value=clip_norm,
|
||||
record_specs=record_specs,
|
||||
noise_generator=gaussian_noise_generator,
|
||||
use_efficient=use_efficient)
|
||||
use_efficient=use_efficient,
|
||||
restart_indicator=restart_indicator)
|
||||
|
||||
|
||||
@tf.function
|
||||
|
|
|
@ -263,13 +263,13 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertEqual(query_result, expected_sum)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('s0t1step8', 0., 1., [1., 1., 2., 1., 2., 2., 3., 1.]),
|
||||
('s1t1step8', 1., 1., [2., 3., 5., 5., 7., 8., 10., 9.]),
|
||||
('s1t2step8', 1., 2., [3., 4., 7., 6., 9., 10., 13., 10.]),
|
||||
('s0t1', 0., 1.),
|
||||
('s1t1', 1., 1.),
|
||||
('s1t2', 1., 2.),
|
||||
)
|
||||
def test_partial_sum_scalar_tree_aggregation(self, scalar_value,
|
||||
tree_node_value,
|
||||
expected_values):
|
||||
tree_node_value):
|
||||
total_steps = 8
|
||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
|
@ -279,14 +279,54 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for val in expected_values:
|
||||
# For each streaming step i , the expected value is roughly
|
||||
# `scalar_value*i + tree_aggregation(tree_node_value, i)`
|
||||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
self.assertEqual(query_result, val)
|
||||
# For each streaming step i , the expected value is roughly
|
||||
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
|
||||
# The tree aggregation value can be inferred from the binary
|
||||
# representation of the current step.
|
||||
self.assertEqual(
|
||||
query_result,
|
||||
scalar_value * (i + 1) + tree_node_value * bin(i + 1)[2:].count('1'))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('s0t1f1', 0., 1., 1),
|
||||
('s0t1f2', 0., 1., 2),
|
||||
('s0t1f5', 0., 1., 5),
|
||||
('s1t1f5', 1., 1., 5),
|
||||
('s1t2f2', 1., 2., 2),
|
||||
('s1t5f6', 1., 5., 6),
|
||||
)
|
||||
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
|
||||
tree_node_value, frequency):
|
||||
total_steps = 20
|
||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
noise_generator=lambda: tree_node_value,
|
||||
record_specs=tf.TensorSpec([]),
|
||||
use_efficient=False,
|
||||
restart_indicator=indicator,
|
||||
)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# Expected value is the combination of cumsum of signal; sum of trees
|
||||
# that have been reset; current tree sum. The tree aggregation value can
|
||||
# be inferred from the binary representation of the current step.
|
||||
expected = (
|
||||
scalar_value * (i + 1) +
|
||||
i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
|
||||
tree_node_value * bin(i % frequency + 1)[2:].count('1'))
|
||||
self.assertEqual(query_result, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('efficient', True, tree_aggregation.EfficientTreeAggregator),
|
||||
|
@ -395,6 +435,42 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
self.assertIsInstance(query._tree_aggregator, tree_class)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('s0t1f1', 0., 1., 1),
|
||||
('s0t1f2', 0., 1., 2),
|
||||
('s0t1f5', 0., 1., 5),
|
||||
('s1t1f5', 1., 1., 5),
|
||||
('s1t2f2', 1., 2., 2),
|
||||
('s1t5f6', 1., 5., 6),
|
||||
)
|
||||
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
|
||||
frequency):
|
||||
total_steps = 20
|
||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
query = tree_aggregation_query.TreeResidualSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
noise_generator=lambda: tree_node_value,
|
||||
record_specs=tf.TensorSpec([]),
|
||||
use_efficient=False,
|
||||
restart_indicator=indicator,
|
||||
)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# Expected value is the signal of the current round plus the residual of
|
||||
# two continous tree aggregation values. The tree aggregation value can
|
||||
# be inferred from the binary representation of the current step.
|
||||
expected = scalar_value + tree_node_value * (
|
||||
bin(i % frequency + 1)[2:].count('1') -
|
||||
bin(i % frequency)[2:].count('1'))
|
||||
print(i, query_result, expected)
|
||||
self.assertEqual(query_result, expected)
|
||||
|
||||
|
||||
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
|
|
@ -365,5 +365,26 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
|||
self.assertAllEqual(gstate, gstate2)
|
||||
|
||||
|
||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
||||
def test_round_raise(self, frequency):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Restart frequency should be equal or larger than 1'):
|
||||
tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
|
||||
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
||||
def test_round_indicator(self, frequency):
|
||||
total_steps = 20
|
||||
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
|
||||
state = indicator.initialize()
|
||||
for i in range(total_steps):
|
||||
flag, state = indicator.next(state)
|
||||
if i % frequency == frequency - 1:
|
||||
self.assertTrue(flag)
|
||||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue