Restart the tree state in tree-related DPQuery for streaming data: a general abstract class and an instance that restarts every few rounds.

PiperOrigin-RevId: 390244330
This commit is contained in:
Zheng Xu 2021-08-11 16:25:32 -07:00 committed by A. Unique TensorFlower
parent f44dcb8760
commit b4c04093cf
4 changed files with 399 additions and 67 deletions


@ -16,7 +16,10 @@
`TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise
based on tree aggregation. When using an appropriate noise function (e.g.,
Gaussian noise), it allows for efficient differentially private algorithms under
continual observation, without prior subsampling or shuffling assumptions. This
module implements the core logic of tree aggregation in TensorFlow, which serves
as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible.
""" """
import abc import abc
@ -26,6 +29,10 @@ import attr
import tensorflow as tf
# TODO(b/192464750): find a proper place for the helper functions, privatize
# the tree aggregation logic, and encourage users to use the DPQuery API.
class ValueGenerator(metaclass=abc.ABCMeta):
"""Base class establishing interface for stateful value generation.
@ -40,6 +47,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
Returns:
An initial state.
"""
raise NotImplementedError
@abc.abstractmethod
def next(self, state):
@ -52,6 +60,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
A pair (value, new_state) where value is the next value and new_state
is the advanced state.
"""
raise NotImplementedError
class GaussianNoiseGenerator(ValueGenerator):
@ -148,6 +157,78 @@ class StatelessValueGenerator(ValueGenerator):
return self.value_fn(), state
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
class RestartIndicator(metaclass=abc.ABCMeta):
"""Base class establishing interface for restarting the tree state.
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
value is generated to indicate whether to restart, and the indicator state is
advanced.
"""
@abc.abstractmethod
def initialize(self):
"""Makes an initialized state for `RestartIndicator`.
Returns:
An initial state.
"""
raise NotImplementedError
@abc.abstractmethod
def next(self, state):
"""Gets next bool indicator and advances the `RestartIndicator` state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is a bool indicator and new_state
is the advanced state.
"""
raise NotImplementedError
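For orientation, a minimal sketch (not part of this change) of how a user-defined indicator could build on the interface above; the class name `ScheduledRestartIndicator` and its `restart_rounds` argument are hypothetical:

# Assumes `import tensorflow as tf`, as at the top of this module.
class ScheduledRestartIndicator(RestartIndicator):
  """Restarts at a precomputed, 1-indexed set of rounds (hypothetical sketch)."""

  def __init__(self, restart_rounds):
    # `restart_rounds` is a Python sequence of round numbers at which `next`
    # should return True, e.g. [100, 300, 700].
    self._restart_rounds = tf.constant(sorted(restart_rounds), tf.int32)

  def initialize(self):
    return tf.constant(0, tf.int32)

  def next(self, state):
    state = state + tf.constant(1, tf.int32)
    flag = tf.reduce_any(tf.equal(state, self._restart_rounds))
    return flag, state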
class PeriodicRoundRestartIndicator(RestartIndicator):
"""Indicator for resetting the tree state after every a few number of queries.
The indicator will maintain an internal counter as state.
"""
def __init__(self, frequency: int):
"""Construct the `PeriodicRoundRestartIndicator`.
Args:
frequency: The `next` function will return `True` every `frequency` number
of `next` calls.
"""
if frequency < 1:
raise ValueError('Restart frequency should be equal or larger than 1 '
f'got {frequency}')
self.frequency = tf.constant(frequency, tf.int32)
def initialize(self):
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
return tf.constant(0, tf.int32)
def next(self, state):
"""Gets next bool indicator and advances the state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is the bool indicator and new_state
is `state+1`.
"""
state = state + tf.constant(1, tf.int32)
flag = state % self.frequency == 0
return flag, state
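As a quick illustration of the behavior above (an eager-mode sketch, not part of the diff), with `frequency=3` the indicator fires on every third call to `next`:

indicator = PeriodicRoundRestartIndicator(frequency=3)
state = indicator.initialize()
flags = []
for _ in range(6):
  flag, state = indicator.next(state)
  flags.append(bool(flag))
# flags == [False, False, True, False, False, True]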
@attr.s(eq=False, frozen=True, slots=True)
class TreeState(object):
"""Class defining state of the tree.
@ -166,6 +247,7 @@ class TreeState(object):
value_generator_state = attr.ib(type=Any)
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
@tf.function
def get_step_idx(state: TreeState) -> tf.Tensor:
"""Returns the current leaf node index based on `TreeState.level_buffer_idx`."""
@ -188,6 +270,14 @@ class TreeAggregator():
https://dl.acm.org/doi/pdf/10.1145/1806689.1806787. A buffer at the scale of
tree depth is maintained and updated when a new conceptual leaf node arrives.
Example usage:
random_generator = GaussianNoiseGenerator(...)
tree_aggregator = TreeAggregator(random_generator)
state = tree_aggregator.init_state()
for leaf_node_idx in range(total_steps):
assert leaf_node_idx == get_step_idx(state)
noise, state = tree_aggregator.get_cumsum_and_update(state)
Attributes:
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
value for each tree node.
@ -205,14 +295,8 @@ class TreeAggregator():
else:
self.value_generator = StatelessValueGenerator(value_generator)
def _get_init_state(self, value_generator_state) -> TreeState:
"""Returns initial `TreeState` given `value_generator_state`."""
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
0, dtype=tf.int32)).stack()
@ -224,12 +308,28 @@ class TreeAggregator():
new_val)
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
level_buffer_structure, new_val)
return TreeState(
level_buffer=level_buffer,
level_buffer_idx=level_buffer_idx,
value_generator_state=value_generator_state)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
Returns:
An initialized `TreeState`.
"""
value_generator_state = self.value_generator.initialize()
return self._get_init_state(value_generator_state)
def reset_state(self, state: TreeState) -> TreeState:
"""Returns reset `TreeState` after restarting a new tree."""
return self._get_init_state(state.value_generator_state)
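A small usage sketch of `reset_state` (eager-mode; uses a constant no-arg value function, which the class accepts per the Attributes above, instead of a `GaussianNoiseGenerator`):

tree_aggregator = TreeAggregator(lambda: tf.constant(1.0))
state = tree_aggregator.init_state()
for step in range(10):
  cumsum_noise, state = tree_aggregator.get_cumsum_and_update(state)
  if (step + 1) % 5 == 0:
    # Start a new tree every 5 leaf nodes; the value generator state carries over.
    state = tree_aggregator.reset_state(state)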
@tf.function
def _get_cumsum(self, level_buffer: Collection[tf.Tensor]) -> tf.Tensor:
return tf.nest.map_structure(lambda x: tf.reduce_sum(x, axis=0),
@ -238,7 +338,7 @@ class TreeAggregator():
@tf.function
def get_cumsum_and_update(self,
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
"""Returns tree aggregated noise and updates `TreeState` for the next step.
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
that `get_step_idx` can be called to get the current index of the leaf node
@ -249,10 +349,20 @@ class TreeAggregator():
Args:
state: `TreeState` for the current leaf node, index can be queried by
`tree_aggregation.get_step_idx(state.level_buffer_idx)`.
Returns:
Tuple of (noise, state) where `noise` is generated by the tree aggregation
protocol for the cumulative sum of streaming data, and `state` is the
updated `TreeState`.
""" """
level_buffer_idx, level_buffer, value_generator_state = ( level_buffer_idx, level_buffer, value_generator_state = (
state.level_buffer_idx, state.level_buffer, state.value_generator_state) state.level_buffer_idx, state.level_buffer, state.value_generator_state)
# We only expose a combined function for updating the state and returning
# noised results because this DPQuery is designed for streaming data,
# and we only maintain a dynamic memory buffer of max size logT. Only the
# most recent noised results can be queried, and the queries are
# expected to happen for every step in the streaming setting.
cumsum = self._get_cumsum(level_buffer)
new_level_buffer = tf.nest.map_structure(
@ -311,6 +421,14 @@ class EfficientTreeAggregator():
`sigma * sqrt(2^{d-1}/(2^d-1))`, which becomes `sigma / sqrt(2)` when
the tree is very tall.
Example usage:
random_generator = GaussianNoiseGenerator(...)
tree_aggregator = EfficientTreeAggregator(random_generator)
state = tree_aggregator.init_state()
for leaf_node_idx in range(total_steps):
assert leaf_node_idx == get_step_idx(state)
noise, state = tree_aggregator.get_cumsum_and_update(state)
Attributes:
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
value for each tree node.
@ -328,17 +446,8 @@ class EfficientTreeAggregator():
else:
self.value_generator = StatelessValueGenerator(value_generator)
def _get_init_state(self, value_generator_state):
"""Returns initial buffer for `TreeState`."""
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
0, dtype=tf.int32)).stack()
@ -350,12 +459,28 @@ class EfficientTreeAggregator():
new_val)
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
level_buffer_structure, new_val)
return TreeState(
level_buffer=level_buffer,
level_buffer_idx=level_buffer_idx,
value_generator_state=value_generator_state)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
Returns:
An initialized `TreeState`.
"""
value_generator_state = self.value_generator.initialize()
return self._get_init_state(value_generator_state)
def reset_state(self, state: TreeState) -> TreeState:
"""Returns reset `TreeState` after restarting a new tree."""
return self._get_init_state(state.value_generator_state)
@tf.function
def _get_cumsum(self, state: TreeState) -> tf.Tensor:
"""Returns weighted cumulative sum of noise based on `TreeState`."""
@ -377,7 +502,7 @@ class EfficientTreeAggregator():
@tf.function
def get_cumsum_and_update(self,
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
"""Returns tree aggregated noise and updates `TreeState` for the next step.
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
that `get_step_idx` can be called to get the current index of the leaf node
@ -390,7 +515,17 @@ class EfficientTreeAggregator():
Args:
state: `TreeState` for the current leaf node, index can be queried by
`tree_aggregation.get_step_idx(state.level_buffer_idx)`.
Returns:
Tuple of (noise, state) where `noise` is generated by the tree aggregation
protocol for the cumulative sum of streaming data, and `state` is the
updated `TreeState`.
""" """
# We only expose a combined function for updating the state and returning
# noised results because this DPQuery is designed for streaming data,
# and we only maintain a dynamic memory buffer of max size logT. Only the
# most recent noised results can be queried, and the queries are
# expected to happen for every step in the streaming setting.
cumsum = self._get_cumsum(state)
level_buffer_idx, level_buffer, value_generator_state = (


@ -32,13 +32,35 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
"""Returns private cumulative sums by clipping and adding correlated noise.
Consider calling `get_noised_result` T times, and each (x_i, i=0,1,...,T-1) is
the private value returned by `accumulate_record`, i.e. x_i = sum_{j=0}^{n-1}
x_{i,j} where each x_{i,j} is a private record in the database. This class is
intended to make multiple queries, which release privatized values of the
cumulative sums s_i = sum_{k=0}^{i} x_k, for i=0,...,T-1.
Each call to `get_noised_result` releases the next cumulative sum s_i, which
is in contrast to the GaussianSumQuery that releases x_i. Noise for the
cumulative sums is generated using the tree aggregation logic in
`tree_aggregation` and is proportional to log(T).
Example usage:
query = TreeCumulativeSumQuery(...)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i, samples in enumerate(streaming_samples):
sample_state = query.initial_sample_state(samples[0])
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
for j,sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_cumsum is privatized estimate of s_i
noised_cumsum, global_state = query.get_noised_result(
sample_state, global_state)
Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
@ -52,6 +74,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
node. Noise standard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
@attr.s(frozen=True) @attr.s(frozen=True)
@ -63,17 +87,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
each level state.
clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
""" """
tree_state = attr.ib() tree_state = attr.ib()
clip_value = attr.ib() clip_value = attr.ib()
samples_cumulative_sum = attr.ib() samples_cumulative_sum = attr.ib()
restarter_state = attr.ib()
def __init__(self,
record_specs,
noise_generator,
clip_fn,
clip_value,
use_efficient=True,
restart_indicator=None):
"""Initializes the `TreeCumulativeSumQuery`. """Initializes the `TreeCumulativeSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a Consider using `build_l2_gaussian_query` for the construction of a
@ -91,6 +119,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
self._clip_fn = clip_fn self._clip_fn = clip_fn
self._clip_value = clip_value self._clip_value = clip_value
@ -100,17 +130,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
noise_generator)
else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state()
initial_samples_cumulative_sum = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape), self._record_specs)
restarter_state = None
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeCumulativeSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
samples_cumulative_sum=initial_samples_cumulative_sum,
restarter_state=restarter_state)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -151,13 +185,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
tf.add, global_state.samples_cumulative_sum, sample_state)
cumulative_sum_noise, new_tree_state = self._tree_aggregator.get_cumsum_and_update(
global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
new_cumulative_sum = noised_cumulative_sum
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve(
global_state,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state,
restarter_state=restarter_state)
return noised_cumulative_sum, new_global_state
@classmethod
def build_l2_gaussian_query(cls,
@ -165,7 +207,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
noise_multiplier,
record_specs,
noise_seed=None,
use_efficient=True,
restart_indicator=None):
"""Returns a query instance with L2 norm clipping and Gaussian noise. """Returns a query instance with L2 norm clipping and Gaussian noise.
Args: Args:
@ -180,6 +223,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
if clip_norm <= 0: if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -202,22 +247,48 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm,
record_specs=record_specs,
noise_generator=gaussian_noise_generator,
use_efficient=use_efficient,
restart_indicator=restart_indicator)
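To tie the new argument together, a construction sketch for the cumulative-sum query (argument values and the record shape are illustrative; the `tree_aggregation_query` import path is assumed to mirror the `tree_aggregation` import shown earlier in this diff):

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query

# Restart the tree state every 100 calls to `get_noised_result`.
restart_indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency=100)
query = tree_aggregation_query.TreeCumulativeSumQuery.build_l2_gaussian_query(
    clip_norm=1.0,
    noise_multiplier=1.0,
    record_specs=tf.TensorSpec([10]),
    restart_indicator=restart_indicator)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)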
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery for adding correlated noise through tree structure.
Clips and sums records in current sample x_i = sum_{j=0}^{n-1} x_{i,j};
returns the current sample adding the noise residual from tree aggregation.
The returned value is conceptually equivalent to the following: calculates
cumulative sum of samples over time s_i = sum_{k=0}^i x_k (instead of only
current sample) with added noise by the tree aggregation protocol that is
proportional to log(T), T being the number of times the query is called;
returns the residual between the current noised cumsum noised(s_i) and the
previous one noised(s_{i-1}) when the query is called.
This can be used as a drop-in replacement for `GaussianSumQuery`, and can
offer stronger utility/privacy tradeoffs when amplification-via-sampling is not
possible, or when the privacy epsilon is relatively large. This may result in
more noise by a log(T) factor in each individual estimate of x_i, but if the
x_i are used in the underlying code to compute cumulative sums, the noise in
those sums can be less. That is, this allows us to adapt code that was written
to use a regular `SumQuery` to benefit from the tree aggregation protocol.
Combining this query with an SGD optimizer can be used to implement the
DP-FTRL algorithm in
"Practical and Private (Deep) Learning without Sampling or Shuffling".
Example usage:
query = TreeResidualSumQuery(...)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i, samples in enumerate(streaming_samples):
sample_state = query.initial_sample_state(samples[0])
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
for j,sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_sum is privatized estimate of x_i by conceptually postprocessing
# noised cumulative sum s_i
noised_sum, global_state = query.get_noised_result(
sample_state, global_state)
Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
arguments: a flat list of vars in a record and a `clip_value` to clip the
@ -231,6 +302,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
node. Noise standard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order
O(clip_norm*log(T)/eps) to guarantee eps-DP.
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
@attr.s(frozen=True) @attr.s(frozen=True)
@ -243,21 +316,25 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value: The clipping value to be passed to clip_fn.
previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample.
restarter_state: Current state of the restarter to indicate whether
the tree state will be reset.
""" """
tree_state = attr.ib() tree_state = attr.ib()
clip_value = attr.ib() clip_value = attr.ib()
previous_tree_noise = attr.ib() previous_tree_noise = attr.ib()
restarter_state = attr.ib()
def __init__(self,
record_specs,
noise_generator,
clip_fn,
clip_value,
use_efficient=True,
restart_indicator=None):
"""Initializes the `TreeResidualSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a
`TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
Args:
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
@ -271,6 +348,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
self._clip_fn = clip_fn self._clip_fn = clip_fn
self._clip_value = clip_value self._clip_value = clip_value
@ -280,16 +359,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noise_generator)
else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
self._restart_indicator = restart_indicator
def _zero_initial_noise(self):
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
self._record_specs)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state()
restarter_state = None
if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize()
return TreeResidualSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
previous_tree_noise=self._zero_initial_noise(),
restarter_state=restarter_state)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -328,8 +414,18 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise,
global_state.previous_tree_noise)
restarter_state = global_state.restarter_state
if self._restart_indicator is not None:
restart_flag, restarter_state = self._restart_indicator.next(
restarter_state)
if restart_flag:
tree_noise = self._zero_initial_noise()
new_tree_state = self._tree_aggregator.reset_state(new_tree_state)
new_global_state = attr.evolve(
global_state,
previous_tree_noise=tree_noise,
tree_state=new_tree_state,
restarter_state=restarter_state)
return noised_sample, new_global_state
@classmethod
@ -338,7 +434,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noise_multiplier,
record_specs,
noise_seed=None,
use_efficient=True,
restart_indicator=None):
"""Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. """Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
Args: Args:
@ -353,6 +450,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
use_efficient: Boolean indicating the usage of the efficient tree
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
restart_indicator: `tree_aggregation.RestartIndicator` to generate the
boolean indicator for resetting the tree state.
""" """
if clip_norm <= 0: if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
@ -375,7 +474,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_value=clip_norm,
record_specs=record_specs,
noise_generator=gaussian_noise_generator,
use_efficient=use_efficient,
restart_indicator=restart_indicator)
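Similarly, a round-based usage sketch for the residual query with periodic restarts (a sketch only; `streaming_rounds`, the record shape, and the argument values are placeholders):

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query

query = tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
    clip_norm=1.0,
    noise_multiplier=1.0,
    record_specs=tf.TensorSpec([10]),
    restart_indicator=tree_aggregation.PeriodicRoundRestartIndicator(50))
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for round_records in streaming_rounds:  # placeholder: one iterable of records per round
  sample_state = query.initial_sample_state(tf.zeros([10]))
  for record in round_records:
    sample_state = query.accumulate_record(params, sample_state, record)
  # Privatized estimate of the per-round sum; the correlated noise restarts every 50 rounds.
  noised_sum, global_state = query.get_noised_result(sample_state, global_state)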
@tf.function


@ -263,13 +263,13 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(query_result, expected_sum)
@parameterized.named_parameters(
('s0t1', 0., 1.),
('s1t1', 1., 1.),
('s1t2', 1., 2.),
)
def test_partial_sum_scalar_tree_aggregation(self, scalar_value,
tree_node_value):
total_steps = 8
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
@ -279,14 +279,54 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
# For each streaming step i, the expected value is roughly
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
# The tree aggregation value can be inferred from the binary
# representation of the current step.
self.assertEqual(
query_result,
scalar_value * (i + 1) + tree_node_value * bin(i + 1)[2:].count('1'))
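To make the assertion above concrete, a small worked check of the bit-count formula (plain Python; with scalar_value=0 and tree_node_value=1 it reproduces the [1., 1., 2., 1., 2., 2., 3., 1.] sequence that the previous version of this test hard-coded):

scalar_value, tree_node_value = 0., 1.
expected = [
    scalar_value * (i + 1) + tree_node_value * bin(i + 1)[2:].count('1')
    for i in range(8)
]
assert expected == [1., 1., 2., 1., 2., 2., 3., 1.]
# E.g. at step i = 6, i + 1 = 7 = 0b111 has three set bits, so three tree
# nodes contribute noise and the expected result is scalar_value * 7 + 3.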
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
tree_node_value, frequency):
total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False,
restart_indicator=indicator,
)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
# Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = (
scalar_value * (i + 1) +
i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
tree_node_value * bin(i % frequency + 1)[2:].count('1'))
self.assertEqual(query_result, expected)
@parameterized.named_parameters(
('efficient', True, tree_aggregation.EfficientTreeAggregator),
@ -395,6 +435,42 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
)
self.assertIsInstance(query._tree_aggregator, tree_class)
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
frequency):
total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False,
restart_indicator=indicator,
)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
sample_state, global_state)
# Expected value is the signal of the current round plus the residual of
# two consecutive tree aggregation values. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = scalar_value + tree_node_value * (
bin(i % frequency + 1)[2:].count('1') -
bin(i % frequency)[2:].count('1'))
print(i, query_result, expected)
self.assertEqual(query_result, expected)
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):


@ -365,5 +365,26 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
self.assertAllEqual(gstate, gstate2)
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('zero', 0), ('negative', -1))
def test_round_raise(self, frequency):
with self.assertRaisesRegex(
ValueError, 'Restart frequency should be equal or larger than 1'):
tree_aggregation.PeriodicRoundRestartIndicator(frequency)
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
def test_round_indicator(self, frequency):
total_steps = 20
indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency)
state = indicator.initialize()
for i in range(total_steps):
flag, state = indicator.next(state)
if i % frequency == frequency - 1:
self.assertTrue(flag)
else:
self.assertFalse(flag)
if __name__ == '__main__':
tf.test.main()