From b4c04093cfa4250d44579204c2c1c196646e482e Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Wed, 11 Aug 2021 16:25:32 -0700 Subject: [PATCH] Restart the tree state in tree related DPQuery for streaming data: a general abstract class and an instance of restarting every a few rounds. PiperOrigin-RevId: 390244330 --- .../privacy/dp_query/tree_aggregation.py | 183 +++++++++++++++--- .../dp_query/tree_aggregation_query.py | 168 ++++++++++++---- .../dp_query/tree_aggregation_query_test.py | 94 ++++++++- .../privacy/dp_query/tree_aggregation_test.py | 21 ++ 4 files changed, 399 insertions(+), 67 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py index ba8ea2f..6015545 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py @@ -16,7 +16,10 @@ `TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise based on tree aggregation. When using an appropriate noise function (e.g., Gaussian noise), it allows for efficient differentially private algorithms under -continual observation, without prior subsampling or shuffling assumptions. +continual observation, without prior subsampling or shuffling assumptions. This +module implements the core logic of tree aggregation in Tensorflow, which serves +as helper functions for `tree_aggregation_query`. This module and helper +functions are publicly accessible. """ import abc @@ -26,6 +29,10 @@ import attr import tensorflow as tf +# TODO(b/192464750): find a proper place for the helper functions, privatize +# the tree aggregation logic, and encourage users to use the DPQuery API. + + class ValueGenerator(metaclass=abc.ABCMeta): """Base class establishing interface for stateful value generation. @@ -40,6 +47,7 @@ class ValueGenerator(metaclass=abc.ABCMeta): Returns: An initial state. 
""" + raise NotImplementedError @abc.abstractmethod def next(self, state): @@ -52,6 +60,7 @@ class ValueGenerator(metaclass=abc.ABCMeta): A pair (value, new_state) where value is the next value and new_state is the advanced state. """ + raise NotImplementedError class GaussianNoiseGenerator(ValueGenerator): @@ -148,6 +157,78 @@ class StatelessValueGenerator(ValueGenerator): return self.value_fn(), state +# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be +# in the same module. + + +class RestartIndicator(metaclass=abc.ABCMeta): + """Base class establishing interface for restarting the tree state. + + A `RestartIndicator` maintains a state, and each time `next` is called, a bool + value is generated to indicate whether to restart, and the indicator state is + advanced. + """ + + @abc.abstractmethod + def initialize(self): + """Makes an initialized state for `RestartIndicator`. + + Returns: + An initial state. + """ + raise NotImplementedError + + @abc.abstractmethod + def next(self, state): + """Gets next bool indicator and advances the `RestartIndicator` state. + + Args: + state: The current state. + + Returns: + A pair (value, new_state) where value is bool indicator and new_state + is the advanced state. + """ + raise NotImplementedError + + +class PeriodicRoundRestartIndicator(RestartIndicator): + """Indicator for resetting the tree state after every a few number of queries. + + The indicator will maintain an internal counter as state. + """ + + def __init__(self, frequency: int): + """Construct the `PeriodicRoundRestartIndicator`. + + Args: + frequency: The `next` function will return `True` every `frequency` number + of `next` calls. 
+ """ + if frequency < 1: + raise ValueError('Restart frequency should be equal or larger than 1 ' + f'got {frequency}') + self.frequency = tf.constant(frequency, tf.int32) + + def initialize(self): + """Returns initialized state of 0 for `PeriodicRoundRestartIndicator`.""" + return tf.constant(0, tf.int32) + + def next(self, state): + """Gets next bool indicator and advances the state. + + Args: + state: The current state. + + Returns: + A pair (value, new_state) where value is the bool indicator and new_state + of `state+1`. + """ + state = state + tf.constant(1, tf.int32) + flag = state % self.frequency == 0 + return flag, state + + @attr.s(eq=False, frozen=True, slots=True) class TreeState(object): """Class defining state of the tree. @@ -166,6 +247,7 @@ class TreeState(object): value_generator_state = attr.ib(type=Any) +# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`. @tf.function def get_step_idx(state: TreeState) -> tf.Tensor: """Returns the current leaf node index based on `TreeState.level_buffer_idx`.""" @@ -188,6 +270,14 @@ class TreeAggregator(): https://dl.acm.org/doi/pdf/10.1145/1806689.1806787. A buffer at the scale of tree depth is maintained and updated when a new conceptual leaf node arrives. + Example usage: + random_generator = GaussianNoiseGenerator(...) + tree_aggregator = TreeAggregator(random_generator) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + assert leaf_node_idx == get_step_idx(state)) + noise, state = tree_aggregator.get_cumsum_and_update(state) + Attributes: value_generator: A `ValueGenerator` or a no-arg function to generate a noise value for each tree node. @@ -205,14 +295,8 @@ class TreeAggregator(): else: self.value_generator = StatelessValueGenerator(value_generator) - def init_state(self) -> TreeState: - """Returns initial `TreeState`. 
- - Initializes `TreeState` for a tree of a single leaf node: the respective - initial node value in `TreeState.level_buffer` is generated by the value - generator function, and the node index is 0. - """ - value_generator_state = self.value_generator.initialize() + def _get_init_state(self, value_generator_state) -> TreeState: + """Returns initial `TreeState` given `value_generator_state`.""" level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True) level_buffer_idx = level_buffer_idx.write(0, tf.constant( 0, dtype=tf.int32)).stack() @@ -224,12 +308,28 @@ class TreeAggregator(): new_val) level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(), level_buffer_structure, new_val) - return TreeState( level_buffer=level_buffer, level_buffer_idx=level_buffer_idx, value_generator_state=value_generator_state) + def init_state(self) -> TreeState: + """Returns initial `TreeState`. + + Initializes `TreeState` for a tree of a single leaf node: the respective + initial node value in `TreeState.level_buffer` is generated by the value + generator function, and the node index is 0. + + Returns: + An initialized `TreeState`. + """ + value_generator_state = self.value_generator.initialize() + return self._get_init_state(value_generator_state) + + def reset_state(self, state: TreeState) -> TreeState: + """Returns reset `TreeState` after restarting a new tree.""" + return self._get_init_state(state.value_generator_state) + @tf.function def _get_cumsum(self, level_buffer: Collection[tf.Tensor]) -> tf.Tensor: return tf.nest.map_structure(lambda x: tf.reduce_sum(x, axis=0), @@ -238,7 +338,7 @@ class TreeAggregator(): @tf.function def get_cumsum_and_update(self, state: TreeState) -> Tuple[tf.Tensor, TreeState]: - """Returns tree aggregated value and updated `TreeState` for one step. + """Returns tree aggregated noise and updates `TreeState` for the next step. `TreeState` is updated to prepare for accepting the *next* leaf node. 
Note that `get_step_idx` can be called to get the current index of the leaf node @@ -249,10 +349,20 @@ class TreeAggregator(): Args: state: `TreeState` for the current leaf node, index can be queried by `tree_aggregation.get_step_idx(state.level_buffer_idx)`. + + Returns: + Tuple of (noise, state) where `noise` is generated by tree aggregated + protocol for the cumulative sum of streaming data, and `state` is the + updated `TreeState`. """ level_buffer_idx, level_buffer, value_generator_state = ( state.level_buffer_idx, state.level_buffer, state.value_generator_state) + # We only publicize a combined function for updating state and returning + # noised results because this DPQuery is designed for the streaming data, + # and we only maintain a dynamic memory buffer of max size logT. Only the + # the most recent noised results can be queried, and the queries are + # expected to happen for every step in the streaming setting. cumsum = self._get_cumsum(level_buffer) new_level_buffer = tf.nest.map_structure( @@ -311,6 +421,14 @@ class EfficientTreeAggregator(): `sigma * sqrt(2^{d-1}/(2^d-1))`. which becomes `sigma / sqrt(2)` when the tree is very tall. + Example usage: + random_generator = GaussianNoiseGenerator(...) + tree_aggregator = EfficientTreeAggregator(random_generator) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + assert leaf_node_idx == get_step_idx(state)) + noise, state = tree_aggregator.get_cumsum_and_update(state) + Attributes: value_generator: A `ValueGenerator` or a no-arg function to generate a noise value for each tree node. @@ -328,17 +446,8 @@ class EfficientTreeAggregator(): else: self.value_generator = StatelessValueGenerator(value_generator) - def init_state(self) -> TreeState: - """Returns initial `TreeState`. - - Initializes `TreeState` for a tree of a single leaf node: the respective - initial node value in `TreeState.level_buffer` is generated by the value - generator function, and the node index is 0. 
- - Returns: - An initialized `TreeState`. - """ - value_generator_state = self.value_generator.initialize() + def _get_init_state(self, value_generator_state): + """Returns initial buffer for `TreeState`.""" level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True) level_buffer_idx = level_buffer_idx.write(0, tf.constant( 0, dtype=tf.int32)).stack() @@ -350,12 +459,28 @@ class EfficientTreeAggregator(): new_val) level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(), level_buffer_structure, new_val) - return TreeState( level_buffer=level_buffer, level_buffer_idx=level_buffer_idx, value_generator_state=value_generator_state) + def init_state(self) -> TreeState: + """Returns initial `TreeState`. + + Initializes `TreeState` for a tree of a single leaf node: the respective + initial node value in `TreeState.level_buffer` is generated by the value + generator function, and the node index is 0. + + Returns: + An initialized `TreeState`. + """ + value_generator_state = self.value_generator.initialize() + return self._get_init_state(value_generator_state) + + def reset_state(self, state: TreeState) -> TreeState: + """Returns reset `TreeState` after restarting a new tree.""" + return self._get_init_state(state.value_generator_state) + @tf.function def _get_cumsum(self, state: TreeState) -> tf.Tensor: """Returns weighted cumulative sum of noise based on `TreeState`.""" @@ -377,7 +502,7 @@ class EfficientTreeAggregator(): @tf.function def get_cumsum_and_update(self, state: TreeState) -> Tuple[tf.Tensor, TreeState]: - """Returns tree aggregated value and updated `TreeState` for one step. + """Returns tree aggregated noise and updates `TreeState` for the next step. `TreeState` is updated to prepare for accepting the *next* leaf node. 
Note that `get_step_idx` can be called to get the current index of the leaf node @@ -390,7 +515,17 @@ class EfficientTreeAggregator(): Args: state: `TreeState` for the current leaf node, index can be queried by `tree_aggregation.get_step_idx(state.level_buffer_idx)`. + + Returns: + Tuple of (noise, state) where `noise` is generated by tree aggregated + protocol for the cumulative sum of streaming data, and `state` is the + updated `TreeState`. """ + # We only publicize a combined function for updating state and returning + # noised results because this DPQuery is designed for the streaming data, + # and we only maintain a dynamic memory buffer of max size logT. Only + # the most recent noised results can be queried, and the queries are + # expected to happen for every step in the streaming setting. cumsum = self._get_cumsum(state) level_buffer_idx, level_buffer, value_generator_state = ( diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 082bf01..943cf9f 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -32,13 +32,35 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import tree_aggregation -class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for adding correlated noise through tree structure. +# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be +# in the same module. - First clips and sums records in current sample, returns cumulative sum of - samples over time (instead of only current sample) with added noise for - cumulative sum proportional to log(T), T being the number of times the query - is called. + +class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): + """Returns private cumulative sums by clipping and adding correlated noise. 
+ + Consider calling `get_noised_result` T times, and each (x_i, i=0,1,...,T-1) is + the private value returned by `accumulate_record`, i.e. x_i = sum_{j=0}^{n-1} + x_{i,j} where each x_{i,j} is a private record in the database. This class is + intended to make multiple queries, which release privatized values of the + cumulative sums s_i = sum_{k=0}^{i} x_k, for i=0,...,T-1. + Each call to `get_noised_result` releases the next cumulative sum s_i, which + is in contrast to the GaussianSumQuery that releases x_i. Noise for the + cumulative sums is accomplished using the tree aggregation logic in + `tree_aggregation`, which is proportional to log(T). + + Example usage: + query = TreeCumulativeSumQuery(...) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + for i, samples in enumerate(streaming_samples): + sample_state = query.initial_sample_state(samples[0]) + # Compute x_i = sum_{j=0}^{n-1} x_{i,j} + for j,sample in enumerate(samples): + sample_state = query.accumulate_record(params, sample_state, sample) + # noised_cumsum is privatized estimate of s_i + noised_cumsum, global_state = query.get_noised_result( + sample_state, global_state) Attributes: clip_fn: Callable that specifies clipping function. `clip_fn` receives two @@ -52,6 +74,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): node. Noise stdandard deviation is specified outside the `dp_query` by the user when defining `noise_fn` and should have order O(clip_norm*log(T)/eps) to guarantee eps-DP. + restart_indicator: `tree_aggregation.RestartIndicator` to generate the + boolean indicator for resetting the tree state. """ @attr.s(frozen=True) @@ -63,17 +87,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): each level state. clip_value: The clipping value to be passed to clip_fn. samples_cumulative_sum: Noiseless cumulative sum of samples over time. 
+ restarter_state: Current state of the restarter to indicate whether + the tree state will be reset. """ tree_state = attr.ib() clip_value = attr.ib() samples_cumulative_sum = attr.ib() + restarter_state = attr.ib() def __init__(self, record_specs, noise_generator, clip_fn, clip_value, - use_efficient=True): + use_efficient=True, + restart_indicator=None): """Initializes the `TreeCumulativeSumQuery`. Consider using `build_l2_gaussian_query` for the construction of a @@ -91,6 +119,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". + restart_indicator: `tree_aggregation.RestartIndicator` to generate the + boolean indicator for resetting the tree state. """ self._clip_fn = clip_fn self._clip_value = clip_value @@ -100,17 +130,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): noise_generator) else: self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) + self._restart_indicator = restart_indicator def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() initial_samples_cumulative_sum = tf.nest.map_structure( lambda spec: tf.zeros(spec.shape), self._record_specs) - initial_state = TreeCumulativeSumQuery.GlobalState( + restarter_state = None + if self._restart_indicator is not None: + restarter_state = self._restart_indicator.initialize() + return TreeCumulativeSumQuery.GlobalState( tree_state=initial_tree_state, clip_value=tf.constant(self._clip_value, tf.float32), - samples_cumulative_sum=initial_samples_cumulative_sum) - return initial_state + samples_cumulative_sum=initial_samples_cumulative_sum, + restarter_state=restarter_state) def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" @@ -151,13 
+185,21 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): tf.add, global_state.samples_cumulative_sum, sample_state) cumulative_sum_noise, new_tree_state = self._tree_aggregator.get_cumsum_and_update( global_state.tree_state) + noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum, + cumulative_sum_noise) + restarter_state = global_state.restarter_state + if self._restart_indicator is not None: + restart_flag, restarter_state = self._restart_indicator.next( + restarter_state) + if restart_flag: + new_cumulative_sum = noised_cumulative_sum + new_tree_state = self._tree_aggregator.reset_state(new_tree_state) new_global_state = attr.evolve( global_state, samples_cumulative_sum=new_cumulative_sum, - tree_state=new_tree_state) - noised_cum_sum = tf.nest.map_structure(tf.add, new_cumulative_sum, - cumulative_sum_noise) - return noised_cum_sum, new_global_state + tree_state=new_tree_state, + restarter_state=restarter_state) + return noised_cumulative_sum, new_global_state @classmethod def build_l2_gaussian_query(cls, @@ -165,7 +207,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): noise_multiplier, record_specs, noise_seed=None, - use_efficient=True): + use_efficient=True, + restart_indicator=None): """Returns a query instance with L2 norm clipping and Gaussian noise. Args: @@ -180,6 +223,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". + restart_indicator: `tree_aggregation.RestartIndicator` to generate the + boolean indicator for resetting the tree state. 
""" if clip_norm <= 0: raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') @@ -202,22 +247,48 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): clip_value=clip_norm, record_specs=record_specs, noise_generator=gaussian_noise_generator, - use_efficient=use_efficient) + use_efficient=use_efficient, + restart_indicator=restart_indicator) class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for adding correlated noise through tree structure. + """Implements DPQuery for adding correlated noise through tree structure. - Clips and sums records in current sample; returns the current sample adding - the noise residual from tree aggregation. The returned value is conceptually - equivalent to the following: calculates cumulative sum of samples over time - (instead of only current sample) with added noise for cumulative sum - proportional to log(T), T being the number of times the query is called; - returns the residual between the current noised cumsum and the previous one - when the query is called. Combining this query with a SGD optimizer can be - used to implement the DP-FTRL algorithm in + Clips and sums records in current sample x_i = sum_{j=0}^{n-1} x_{i,j}; + returns the current sample adding the noise residual from tree aggregation. + The returned value is conceptually equivalent to the following: calculates + cumulative sum of samples over time s_i = sum_{k=0}^i x_k (instead of only + current sample) with added noise by tree aggregation protocol that is + proportional to log(T), T being the number of times the query is called; + returns the residual between the current noised cumsum noised(s_i) and the + previous one noised(s_{i-1}) when the query is called. + + This can be used as a drop-in replacement for `GaussianSumQuery`, and can + offer stronger utility/privacy tradeoffs when amplification-via-sampling is not + possible, or when privacy epsilon is relatively large. 
This may result in + more noise by a log(T) factor in each individual estimate of x_i, but if the + x_i are used in the underlying code to compute cumulative sums, the noise in + those sums can be less. That is, this allows us to adapt code that was written + to use a regular `SumQuery` to benefit from the tree aggregation protocol. + + Combining this query with a SGD optimizer can be used to implement the + DP-FTRL algorithm in "Practical and Private (Deep) Learning without Sampling or Shuffling". + Example usage: + query = TreeResidualSumQuery(...) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + for i, samples in enumerate(streaming_samples): + sample_state = query.initial_sample_state(samples[0]) + # Compute x_i = sum_{j=0}^{n-1} x_{i,j} + for j,sample in enumerate(samples): + sample_state = query.accumulate_record(params, sample_state, sample) + # noised_sum is privatized estimate of x_i by conceptually postprocessing + # noised cumulative sum s_i + noised_sum, global_state = query.get_noised_result( + sample_state, global_state) + Attributes: clip_fn: Callable that specifies clipping function. `clip_fn` receives two arguments: a flat list of vars in a record and a `clip_value` to clip the @@ -231,6 +302,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): node. Noise stdandard deviation is specified outside the `dp_query` by the user when defining `noise_fn` and should have order O(clip_norm*log(T)/eps) to guarantee eps-DP. + restart_indicator: `tree_aggregation.RestartIndicator` to generate the + boolean indicator for resetting the tree state. """ @attr.s(frozen=True) @@ -243,21 +316,25 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): clip_value: The clipping value to be passed to clip_fn. previous_tree_noise: Cumulative noise by tree aggregation from the previous time the query is called on a sample. 
+ restarter_state: Current state of the restarter to indicate whether + the tree state will be reset. """ tree_state = attr.ib() clip_value = attr.ib() previous_tree_noise = attr.ib() + restarter_state = attr.ib() def __init__(self, record_specs, noise_generator, clip_fn, clip_value, - use_efficient=True): - """Initializes the `TreeResidualSumQuery`. + use_efficient=True, + restart_indicator=None): + """Initializes the `TreeResidualSumQuery`. Consider using `build_l2_gaussian_query` for the construction of a - `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. + `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. Args: record_specs: A nested structure of `tf.TensorSpec`s specifying structure @@ -271,6 +348,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". + restart_indicator: `tree_aggregation.RestartIndicator` to generate the + boolean indicator for resetting the tree state. 
""" self._clip_fn = clip_fn self._clip_value = clip_value @@ -280,16 +359,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): noise_generator) else: self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) + self._restart_indicator = restart_indicator + + def _zero_initial_noise(self): + return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), + self._record_specs) def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() - initial_noise = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), - self._record_specs) + restarter_state = None + if self._restart_indicator is not None: + restarter_state = self._restart_indicator.initialize() return TreeResidualSumQuery.GlobalState( tree_state=initial_tree_state, clip_value=tf.constant(self._clip_value, tf.float32), - previous_tree_noise=initial_noise) + previous_tree_noise=self._zero_initial_noise(), + restarter_state=restarter_state) def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" @@ -328,8 +414,18 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c, sample_state, tree_noise, global_state.previous_tree_noise) + restarter_state = global_state.restarter_state + if self._restart_indicator is not None: + restart_flag, restarter_state = self._restart_indicator.next( + restarter_state) + if restart_flag: + tree_noise = self._zero_initial_noise() + new_tree_state = self._tree_aggregator.reset_state(new_tree_state) new_global_state = attr.evolve( - global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state) + global_state, + previous_tree_noise=tree_noise, + tree_state=new_tree_state, + restarter_state=restarter_state) return noised_sample, new_global_state @classmethod @@ -338,7 +434,8 @@ class 
TreeResidualSumQuery(dp_query.SumAggregationDPQuery): noise_multiplier, record_specs, noise_seed=None, - use_efficient=True): + use_efficient=True, + restart_indicator=None): """Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. Args: @@ -353,6 +450,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): use_efficient: Boolean indicating the usage of the efficient tree aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". + restart_indicator: `tree_aggregation.RestartIndicator` to generate the + boolean indicator for resetting the tree state. """ if clip_norm <= 0: raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') @@ -375,7 +474,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): clip_value=clip_norm, record_specs=record_specs, noise_generator=gaussian_noise_generator, - use_efficient=use_efficient) + use_efficient=use_efficient, + restart_indicator=restart_indicator) @tf.function diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index a958f26..1bfaa21 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -263,13 +263,13 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertEqual(query_result, expected_sum) @parameterized.named_parameters( - ('s0t1step8', 0., 1., [1., 1., 2., 1., 2., 2., 3., 1.]), - ('s1t1step8', 1., 1., [2., 3., 5., 5., 7., 8., 10., 9.]), - ('s1t2step8', 1., 2., [3., 4., 7., 6., 9., 10., 13., 10.]), + ('s0t1', 0., 1.), + ('s1t1', 1., 1.), + ('s1t2', 1., 2.), ) def test_partial_sum_scalar_tree_aggregation(self, scalar_value, - tree_node_value, - expected_values): + tree_node_value): + total_steps = 8 query = tree_aggregation_query.TreeCumulativeSumQuery( clip_fn=_get_l2_clip_fn(), clip_value=scalar_value + 1., # 
no clip @@ -279,14 +279,54 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): ) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - for val in expected_values: - # For each streaming step i , the expected value is roughly - # `scalar_value*i + tree_aggregation(tree_node_value, i)` + for i in range(total_steps): sample_state = query.initial_sample_state(scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value) query_result, global_state = query.get_noised_result( sample_state, global_state) - self.assertEqual(query_result, val) + # For each streaming step i , the expected value is roughly + # `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`. + # The tree aggregation value can be inferred from the binary + # representation of the current step. + self.assertEqual( + query_result, + scalar_value * (i + 1) + tree_node_value * bin(i + 1)[2:].count('1')) + + @parameterized.named_parameters( + ('s0t1f1', 0., 1., 1), + ('s0t1f2', 0., 1., 2), + ('s0t1f5', 0., 1., 5), + ('s1t1f5', 1., 1., 5), + ('s1t2f2', 1., 2., 2), + ('s1t5f6', 1., 5., 6), + ) + def test_sum_scalar_tree_aggregation_reset(self, scalar_value, + tree_node_value, frequency): + total_steps = 20 + indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency) + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=scalar_value + 1., # no clip + noise_generator=lambda: tree_node_value, + record_specs=tf.TensorSpec([]), + use_efficient=False, + restart_indicator=indicator, + ) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + for i in range(total_steps): + sample_state = query.initial_sample_state(scalar_value) + sample_state = query.accumulate_record(params, sample_state, scalar_value) + query_result, global_state = query.get_noised_result( + sample_state, global_state) + # Expected value is the 
combination of cumsum of signal; sum of trees + # that have been reset; current tree sum. The tree aggregation value can + # be inferred from the binary representation of the current step. + expected = ( + scalar_value * (i + 1) + + i // frequency * tree_node_value * bin(frequency)[2:].count('1') + + tree_node_value * bin(i % frequency + 1)[2:].count('1')) + self.assertEqual(query_result, expected) @parameterized.named_parameters( ('efficient', True, tree_aggregation.EfficientTreeAggregator), @@ -395,6 +435,42 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase): ) self.assertIsInstance(query._tree_aggregator, tree_class) + @parameterized.named_parameters( + ('s0t1f1', 0., 1., 1), + ('s0t1f2', 0., 1., 2), + ('s0t1f5', 0., 1., 5), + ('s1t1f5', 1., 1., 5), + ('s1t2f2', 1., 2., 2), + ('s1t5f6', 1., 5., 6), + ) + def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value, + frequency): + total_steps = 20 + indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency) + query = tree_aggregation_query.TreeResidualSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=scalar_value + 1., # no clip + noise_generator=lambda: tree_node_value, + record_specs=tf.TensorSpec([]), + use_efficient=False, + restart_indicator=indicator, + ) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + for i in range(total_steps): + sample_state = query.initial_sample_state(scalar_value) + sample_state = query.accumulate_record(params, sample_state, scalar_value) + query_result, global_state = query.get_noised_result( + sample_state, global_state) + # Expected value is the signal of the current round plus the residual of + # two consecutive tree aggregation values. The tree aggregation value can + # be inferred from the binary representation of the current step. 
+ expected = scalar_value + tree_node_value * ( + bin(i % frequency + 1)[2:].count('1') - + bin(i % frequency)[2:].count('1')) + print(i, query_result, expected) + self.assertEqual(query_result, expected) + class BuildTreeTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py index 9a8be35..fc5e6cc 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py @@ -365,5 +365,26 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase): self.assertAllEqual(gstate, gstate2) +class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters(('zero', 0), ('negative', -1)) + def test_round_raise(self, frequency): + with self.assertRaisesRegex( + ValueError, 'Restart frequency should be equal or larger than 1'): + tree_aggregation.PeriodicRoundRestartIndicator(frequency) + + @parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5)) + def test_round_indicator(self, frequency): + total_steps = 20 + indicator = tree_aggregation.PeriodicRoundRestartIndicator(frequency) + state = indicator.initialize() + for i in range(total_steps): + flag, state = indicator.next(state) + if i % frequency == frequency - 1: + self.assertTrue(flag) + else: + self.assertFalse(flag) + + if __name__ == '__main__': tf.test.main()