From 944dcd0e178c9e277a30013bcf887338d67cf67f Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Tue, 1 Jun 2021 17:26:38 -0700 Subject: [PATCH] Implement the tree aggregation query in TFP. The core `tree_aggregation` algorithm is from https://github.com/google-research/federated/tree/master/dp_ftrl. The tree_aggregation_query is partially developed by Monica Ribero Diaz when she was a student researcher at Google. PiperOrigin-RevId: 376953302 --- tensorflow_privacy/__init__.py | 3 + .../privacy/dp_query/tree_aggregation.py | 367 ++++++++++++++++ .../dp_query/tree_aggregation_query.py | 355 ++++++++++++++++ .../dp_query/tree_aggregation_query_test.py | 399 ++++++++++++++++++ .../privacy/dp_query/tree_aggregation_test.py | 369 ++++++++++++++++ 5 files changed, 1493 insertions(+) create mode 100644 tensorflow_privacy/privacy/dp_query/tree_aggregation.py create mode 100644 tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py create mode 100644 tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py create mode 100644 tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py diff --git a/tensorflow_privacy/__init__.py b/tensorflow_privacy/__init__.py index fcdfd50..1ef5d2d 100644 --- a/tensorflow_privacy/__init__.py +++ b/tensorflow_privacy/__init__.py @@ -47,6 +47,9 @@ else: from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery + from tensorflow_privacy.privacy.dp_query import tree_aggregation + from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery + from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery # Estimators from tensorflow_privacy.privacy.estimators.dnn import DNNClassifier diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py new file mode 100644 index 0000000..ad84053 --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py @@ -0,0 +1,367 @@ +# Copyright 2021, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tree aggregation algorithm. + +This algorithm computes cumulative sums of noise based on tree aggregation. When +using an appropriate noise function (e.g., Gaussian noise), it allows for +efficient differentially private algorithms under continual observation, without +prior subsampling or shuffling assumptions. +""" + +import abc +from typing import Any, Callable, Collection, Optional, Tuple, Union + +import attr +import tensorflow as tf + + +class ValueGenerator(metaclass=abc.ABCMeta): + """Base class establishing interface for stateful value generation.""" + + @abc.abstractmethod + def initialize(self): + """Returns initialized state.""" + + @abc.abstractmethod + def next(self, state): + """Returns tree node value and updated state.""" + + +class GaussianNoiseGenerator(ValueGenerator): + """Gaussian noise generator with counter as pseudo state.""" + + def __init__(self, + noise_std: float, + specs: Collection[tf.TensorSpec], + seed: Optional[int] = None): + self.noise_std = noise_std + self.specs = specs + self.seed = seed + + def initialize(self): + if self.seed is None: + return tf.cast( + tf.stack([ + tf.math.floor(tf.timestamp() * 1e6), + tf.math.floor(tf.math.log(tf.timestamp() * 1e6)) + ]), + dtype=tf.int64) + else: + return tf.constant(self.seed, dtype=tf.int64, shape=(2,)) + + def next(self, state): + flat_structure = tf.nest.flatten(self.specs) + flat_seeds = [state + i for i in range(len(flat_structure))] + nest_seeds = tf.nest.pack_sequence_as(self.specs, flat_seeds) + + def _get_noise(spec, seed): + return tf.random.stateless_normal( + shape=spec.shape, seed=seed, stddev=self.noise_std) + + nest_noise = tf.nest.map_structure(_get_noise, self.specs, nest_seeds) + return nest_noise, flat_seeds[-1] + 1 + + +class StatelessValueGenerator(ValueGenerator): + """A wrapper for stateless value generator initialized by a no-arg function.""" + + def __init__(self, value_fn): + self.value_fn = value_fn + + def initialize(self): + return () + + def next(self, state): + return self.value_fn(), state + + +@attr.s(eq=False, frozen=True, slots=True) +class TreeState(object): + """Class defining state of the tree. + + Attributes: + level_buffer: A `tf.Tensor` saves the last node value of the left child + entered for the tree levels recorded in `level_buffer_idx`. + level_buffer_idx: A `tf.Tensor` for the tree level index of the + `level_buffer`. The tree level index starts from 0, i.e., + `level_buffer[0]` when `level_buffer_idx[0]==0` recorded the noise value + for the most recent leaf node. + value_generator_state: State of a stateful `ValueGenerator` for tree node. + """ + level_buffer = attr.ib(type=tf.Tensor) + level_buffer_idx = attr.ib(type=tf.Tensor) + value_generator_state = attr.ib(type=Any) + + +@tf.function +def get_step_idx(state: TreeState) -> tf.Tensor: + """Returns the current leaf node index based on `TreeState.level_buffer_idx`.""" + step_idx = tf.constant(-1, dtype=tf.int32) + for i in tf.range(len(state.level_buffer_idx)): + step_idx += tf.math.pow(2, state.level_buffer_idx[i]) + return step_idx + + +class TreeAggregator(): + """Tree aggregator to compute accumulated noise in private algorithms. + + This class implements the tree aggregation algorithm for noise values to + efficiently privatize streaming algorithms based on Dwork et al. (2010) + https://dl.acm.org/doi/pdf/10.1145/1806689.1806787. A buffer at the scale of + tree depth is maintained and updated when a new conceptual leaf node arrives. + + Attributes: + value_generator: A `ValueGenerator` or a no-arg function to generate a noise + value for each tree node. + """ + + def __init__(self, value_generator: Union[ValueGenerator, Callable[[], Any]]): + """Initialize the aggregator with a noise generator.""" + if isinstance(value_generator, ValueGenerator): + self.value_generator = value_generator + else: + self.value_generator = StatelessValueGenerator(value_generator) + + def init_state(self) -> TreeState: + """Returns initial `TreeState`. + + Initializes `TreeState` for a tree of a single leaf node: the respective + initial node value in `TreeState.level_buffer` is generated by the value + generator function, and the node index is 0. + """ + value_generator_state = self.value_generator.initialize() + level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True) + level_buffer_idx = level_buffer_idx.write(0, tf.constant( + 0, dtype=tf.int32)).stack() + + new_val, value_generator_state = self.value_generator.next( + value_generator_state) + level_buffer_structure = tf.nest.map_structure( + lambda x: tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True), + new_val) + level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(), + level_buffer_structure, new_val) + + return TreeState( + level_buffer=level_buffer, + level_buffer_idx=level_buffer_idx, + value_generator_state=value_generator_state) + + @tf.function + def _get_cumsum(self, level_buffer: Collection[tf.Tensor]) -> tf.Tensor: + return tf.nest.map_structure(lambda x: tf.reduce_sum(x, axis=0), + level_buffer) + + @tf.function + def get_cumsum_and_update(self, + state: TreeState) -> Tuple[tf.Tensor, TreeState]: + """Returns tree aggregated value and updated `TreeState` for one step. + + `TreeState` is updated to prepare for accepting the *next* leaf node. Note + that `get_step_idx` can be called to get the current index of the leaf node + before calling this function. This function accept state for the current + leaf node and prepare for the next leaf node because TFF prefers to know + the types of state at initialization. + + Args: + state: `TreeState` for the current leaf node, index can be queried by + `tree_aggregation.get_step_idx(state.level_buffer_idx)`. + """ + + level_buffer_idx, level_buffer, value_generator_state = ( + state.level_buffer_idx, state.level_buffer, state.value_generator_state) + cumsum = self._get_cumsum(level_buffer) + + new_level_buffer = tf.nest.map_structure( + lambda x: tf.TensorArray( # pylint: disable=g-long-lambda + dtype=tf.float32, + size=0, + dynamic_size=True), + level_buffer) + new_level_buffer_idx = tf.TensorArray( + dtype=tf.int32, size=0, dynamic_size=True) + # `TreeState` stores the left child node necessary for computing the cumsum + # noise. To update the buffer, let us find the lowest level that will switch + # from a right child (not in the buffer) to a left child. + level_idx = 0 # new leaf node starts from level 0 + while tf.less(level_idx, len(level_buffer_idx)) and tf.equal( + level_idx, level_buffer_idx[level_idx]): + level_idx += 1 + # Left child nodes for the level lower than `level_idx` will be removed + # and a new node will be created at `level_idx`. + write_buffer_idx = 0 + new_level_buffer_idx = new_level_buffer_idx.write(write_buffer_idx, + level_idx) + new_value, value_generator_state = self.value_generator.next( + value_generator_state) + new_level_buffer = tf.nest.map_structure( + lambda x, y: x.write(write_buffer_idx, y), new_level_buffer, new_value) + write_buffer_idx += 1 + # Buffer index will now different from level index for the old `TreeState` + # i.e., `level_buffer_idx[level_idx] != level_idx`. Rename parameter to + # buffer index for clarity. + buffer_idx = level_idx + while tf.less(buffer_idx, len(level_buffer_idx)): + new_level_buffer_idx = new_level_buffer_idx.write( + write_buffer_idx, level_buffer_idx[buffer_idx]) + new_level_buffer = tf.nest.map_structure( + lambda nb, b: nb.write(write_buffer_idx, b[buffer_idx]), + new_level_buffer, level_buffer) + buffer_idx += 1 + write_buffer_idx += 1 + new_level_buffer_idx = new_level_buffer_idx.stack() + new_level_buffer = tf.nest.map_structure(lambda x: x.stack(), + new_level_buffer) + new_state = TreeState( + level_buffer=new_level_buffer, + level_buffer_idx=new_level_buffer_idx, + value_generator_state=value_generator_state) + return cumsum, new_state + + +class EfficientTreeAggregator(): + """Efficient tree aggregator to compute accumulated noise. + + This class implements the efficient tree aggregation algorithm based on + Honaker 2015 "Efficient Use of Differentially Private Binary Trees". + The noise standard deviation for the note at depth d is roughly + `sigma * sqrt(2^{d-1}/(2^d-1))`. which becomes `sigma / sqrt(2)` when + the tree is very tall. + + Attributes: + value_generator: A `ValueGenerator` or a no-arg function to generate a noise + value for each tree node. + """ + + def __init__(self, value_generator: Union[ValueGenerator, Callable[[], Any]]): + """Initialize the aggregator with a noise generator.""" + if isinstance(value_generator, ValueGenerator): + self.value_generator = value_generator + else: + self.value_generator = StatelessValueGenerator(value_generator) + + def init_state(self) -> TreeState: + """Returns initial `TreeState`. + + Initializes `TreeState` for a tree of a single leaf node: the respective + initial node value in `TreeState.level_buffer` is generated by the value + generator function, and the node index is 0. + """ + value_generator_state = self.value_generator.initialize() + level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True) + level_buffer_idx = level_buffer_idx.write(0, tf.constant( + 0, dtype=tf.int32)).stack() + + new_val, value_generator_state = self.value_generator.next( + value_generator_state) + level_buffer_structure = tf.nest.map_structure( + lambda x: tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True), + new_val) + level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(), + level_buffer_structure, new_val) + + return TreeState( + level_buffer=level_buffer, + level_buffer_idx=level_buffer_idx, + value_generator_state=value_generator_state) + + @tf.function + def _get_cumsum(self, state: TreeState) -> tf.Tensor: + """Returns weighted cumulative sum of noise based on `TreeState`.""" + # Note that the buffer saved recursive results of the weighted average of + # the node value (v) and its two children (l, r), i.e., node = v + (l+r)/2. + # To get unbiased estimation with reduced variance for each node, we have to + # reweight it by 1/(2-2^{-d}) where d is the depth of the node. + level_weights = tf.math.divide( + 1., 2. - tf.math.pow(.5, tf.cast(state.level_buffer_idx, tf.float32))) + + def _weighted_sum(buffer): + expand_shape = [len(level_weights)] + [1] * (len(tf.shape(buffer)) - 1) + weighted_buffer = tf.math.multiply( + buffer, tf.reshape(level_weights, expand_shape)) + return tf.reduce_sum(weighted_buffer, axis=0) + + return tf.nest.map_structure(_weighted_sum, state.level_buffer) + + @tf.function + def get_cumsum_and_update(self, + state: TreeState) -> Tuple[tf.Tensor, TreeState]: + """Returns tree aggregated value and updated `TreeState` for one step. + + `TreeState` is updated to prepare for accepting the *next* leaf node. Note + that `get_step_idx` can be called to get the current index of the leaf node + before calling this function. This function accept state for the current + leaf node and prepare for the next leaf node because TFF prefers to know + the types of state at initialization. Note that the value of new node in + `TreeState.level_buffer` will depend on its two children, and is updated + from bottom up for the right child. + + Args: + state: `TreeState` for the current leaf node, index can be queried by + `tree_aggregation.get_step_idx(state.level_buffer_idx)`. + """ + cumsum = self._get_cumsum(state) + + level_buffer_idx, level_buffer, value_generator_state = ( + state.level_buffer_idx, state.level_buffer, state.value_generator_state) + new_level_buffer = tf.nest.map_structure( + lambda x: tf.TensorArray( # pylint: disable=g-long-lambda + dtype=tf.float32, + size=0, + dynamic_size=True), + level_buffer) + new_level_buffer_idx = tf.TensorArray( + dtype=tf.int32, size=0, dynamic_size=True) + # `TreeState` stores the left child node necessary for computing the cumsum + # noise. To update the buffer, let us find the lowest level that will switch + # from a right child (not in the buffer) to a left child. + level_idx = 0 # new leaf node starts from level 0 + new_value, value_generator_state = self.value_generator.next( + value_generator_state) + while tf.less(level_idx, len(level_buffer_idx)) and tf.equal( + level_idx, level_buffer_idx[level_idx]): + # Recursively update if the current node is a right child. + node_value, value_generator_state = self.value_generator.next( + value_generator_state) + new_value = tf.nest.map_structure( + lambda l, r, n: 0.5 * (l[level_idx] + r) + n, level_buffer, new_value, + node_value) + level_idx += 1 + # A new (left) node will be created at `level_idx`. + write_buffer_idx = 0 + new_level_buffer_idx = new_level_buffer_idx.write(write_buffer_idx, + level_idx) + new_level_buffer = tf.nest.map_structure( + lambda x, y: x.write(write_buffer_idx, y), new_level_buffer, new_value) + write_buffer_idx += 1 + # Buffer index will now different from level index for the old `TreeState` + # i.e., `level_buffer_idx[level_idx] != level_idx`. Rename parameter to + # buffer index for clarity. + buffer_idx = level_idx + while tf.less(buffer_idx, len(level_buffer_idx)): + new_level_buffer_idx = new_level_buffer_idx.write( + write_buffer_idx, level_buffer_idx[buffer_idx]) + new_level_buffer = tf.nest.map_structure( + lambda nb, b: nb.write(write_buffer_idx, b[buffer_idx]), + new_level_buffer, level_buffer) + buffer_idx += 1 + write_buffer_idx += 1 + new_level_buffer_idx = new_level_buffer_idx.stack() + new_level_buffer = tf.nest.map_structure(lambda x: x.stack(), + new_level_buffer) + new_state = TreeState( + level_buffer=new_level_buffer, + level_buffer_idx=new_level_buffer_idx, + value_generator_state=value_generator_state) + return cumsum, new_state diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py new file mode 100644 index 0000000..5580222 --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -0,0 +1,355 @@ +# Copyright 2021, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DPQuery for continual observation queries relying on `tree_aggregation`.""" + +import attr +import tensorflow as tf + +from tensorflow_privacy.privacy.dp_query import dp_query +from tensorflow_privacy.privacy.dp_query import tree_aggregation + + +class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): + """Implements dp_query for adding correlated noise through tree structure. + + First clips and sums records in current sample, returns cumulative sum of + samples over time (instead of only current sample) with added noise for + cumulative sum proportional to log(T), T being the number of times the query + is called. + + Attributes: + clip_fn: Callable that specifies clipping function. `clip_fn` receives two + arguments: a flat list of vars in a record and a `clip_value` to clip the + corresponding record, e.g. clip_fn(flat_record, clip_value). + clip_value: float indicating the value at which to clip the record. + record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + tree_aggregator: `tree_aggregation.TreeAggregator` initialized with + user defined `noise_generator`. `noise_generator` is a + `tree_aggregation.ValueGenerator` to generate the noise value for a tree + node. Noise stdandard deviation is specified outside the `dp_query` by the + user when defining `noise_fn` and should have order + O(clip_norm*log(T)/eps) to guarantee eps-DP. + """ + + @attr.s(frozen=True) + class GlobalState(object): + """Class defining global state for Tree sum queries. + + Attributes: + tree_state: Current state of noise tree keeping track of current leaf and + each level state. + clip_value: The clipping value to be passed to clip_fn. + samples_cumulative_sum: Noiseless cumulative sum of samples over time. + """ + tree_state = attr.ib() + clip_value = attr.ib() + samples_cumulative_sum = attr.ib() + + def __init__(self, + record_specs, + noise_generator, + clip_fn, + clip_value, + use_efficient=True): + """Initializes the `TreeCumulativeSumQuery`. + + Consider using `build_l2_gaussian_query` for the construction of a + `TreeCumulativeSumQuery` with L2 norm clipping and Gaussian noise. + + Args: + record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + noise_generator: `tree_aggregation.ValueGenerator` to generate the noise + value for a tree node. Should be coupled with clipping norm to guarantee + privacy. + clip_fn: Callable that specifies clipping function. Input to clip is a + flat list of vars in a record. + clip_value: Float indicating the value at which to clip the record. + use_efficient: Boolean indicating the usage of the efficient tree + aggregation algorithm based on the paper "Efficient Use of + Differentially Private Binary Trees". + """ + self._clip_fn = clip_fn + self._clip_value = clip_value + self._record_specs = record_specs + if use_efficient: + self._tree_aggregator = tree_aggregation.EfficientTreeAggregator( + noise_generator) + else: + self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) + + def initial_global_state(self): + """Returns initial global state.""" + initial_tree_state = self._tree_aggregator.init_state() + initial_samples_cumulative_sum = tf.nest.map_structure( + lambda spec: tf.zeros(spec.shape), self._record_specs) + initial_state = TreeCumulativeSumQuery.GlobalState( + tree_state=initial_tree_state, + clip_value=tf.constant(self._clip_value, tf.float32), + samples_cumulative_sum=initial_samples_cumulative_sum) + return initial_state + + def derive_sample_params(self, global_state): + return global_state.clip_value + + def preprocess_record(self, params, record): + """Returns the clipped record using `clip_fn` and params. + + Args: + params: `clip_value` for the record. + record: The record to be processed. + + Returns: + Structure of clipped tensors. + """ + clip_value = params + record_as_list = tf.nest.flatten(record) + clipped_as_list = self._clip_fn(record_as_list, clip_value) + return tf.nest.pack_sequence_as(record, clipped_as_list) + + def get_noised_result(self, sample_state, global_state): + """Updates tree, state, and returns noised cumulative sum and updated state. + + Computes new cumulative sum, and returns its noised value. Grows tree_state + by one new leaf, and returns the new state. + + Args: + sample_state: Sum of clipped records for this round. + global_state: Global state with current samples cumulative sum and tree + state. + + Returns: + A tuple of (noised_cumulative_sum, new_global_state). + """ + new_cumulative_sum = tf.nest.map_structure( + tf.add, global_state.samples_cumulative_sum, sample_state) + cumulative_sum_noise, new_tree_state = self._tree_aggregator.get_cumsum_and_update( + global_state.tree_state) + new_global_state = attr.evolve( + global_state, + samples_cumulative_sum=new_cumulative_sum, + tree_state=new_tree_state) + noised_cum_sum = tf.nest.map_structure(tf.add, new_cumulative_sum, + cumulative_sum_noise) + return noised_cum_sum, new_global_state + + @classmethod + def build_l2_gaussian_query(cls, + clip_norm, + noise_multiplier, + record_specs, + noise_seed=None, + use_efficient=True): + """Returns a query instance with L2 norm clipping and Gaussian noise. + + Args: + clip_norm: Each record will be clipped so that it has L2 norm at most + `clip_norm`. + noise_multiplier: The effective noise multiplier for the sum of records. + Noise standard deviation is `clip_norm*noise_multiplier`. + record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + noise_seed: Integer seed for the Gaussian noise generator. If `None`, a + nondeterministic seed based on system time will be generated. + use_efficient: Boolean indicating the usage of the efficient tree + aggregation algorithm based on the paper "Efficient Use of + Differentially Private Binary Trees". + """ + if clip_norm <= 0: + raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') + + if noise_multiplier < 0: + raise ValueError( + f'`noise_multiplier` must be non-negative, got {noise_multiplier}.') + + gaussian_noise_generator = tree_aggregation.GaussianNoiseGenerator( + noise_std=clip_norm * noise_multiplier, + specs=record_specs, + seed=noise_seed) + + def l2_clip_fn(record_as_list, clip_norm): + clipped_record, _ = tf.clip_by_global_norm(record_as_list, clip_norm) + return clipped_record + + return cls( + clip_fn=l2_clip_fn, + clip_value=clip_norm, + record_specs=record_specs, + noise_generator=gaussian_noise_generator, + use_efficient=use_efficient) + + +class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): + """Implements dp_query for adding correlated noise through tree structure. + + Clips and sums records in current sample; returns the current sample adding + the noise residual from tree aggregation. The returned value is conceptually + equivalent to the following: calculates cumulative sum of samples over time + (instead of only current sample) with added noise for cumulative sum + proportional to log(T), T being the number of times the query is called; + returns the residual between the current noised cumsum and the previous one + when the query is called. Combining this query with a SGD optimizer can be + used to implement the DP-FTRL algorithm in + "Practical and Private (Deep) Learning without Sampling or Shuffling". + + Attributes: + clip_fn: Callable that specifies clipping function. `clip_fn` receives two + arguments: a flat list of vars in a record and a `clip_value` to clip the + corresponding record, e.g. clip_fn(flat_record, clip_value). + clip_value: float indicating the value at which to clip the record. + record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + tree_aggregator: `tree_aggregation.TreeAggregator` initialized with user + defined `noise_generator`. `noise_generator` is a + `tree_aggregation.ValueGenerator` to generate the noise value for a tree + node. Noise stdandard deviation is specified outside the `dp_query` by the + user when defining `noise_fn` and should have order + O(clip_norm*log(T)/eps) to guarantee eps-DP. + """ + + @attr.s(frozen=True) + class GlobalState(object): + """Class defining global state for Tree sum queries. + + Attributes: + tree_state: Current state of noise tree keeping track of current leaf and + each level state. + clip_value: The clipping value to be passed to clip_fn. + previous_tree_noise: Cumulative noise by tree aggregation from the + previous time the query is called on a sample. + """ + tree_state = attr.ib() + clip_value = attr.ib() + previous_tree_noise = attr.ib() + + def __init__(self, + record_specs, + noise_generator, + clip_fn, + clip_value, + use_efficient=True): + """Initializes the `TreeResidualSumQuery`. + + Consider using `build_l2_gaussian_query` for the construction of a + `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. + + Args: + record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + noise_generator: `tree_aggregation.ValueGenerator` to generate the noise + value for a tree node. Should be coupled with clipping norm to guarantee + privacy. + clip_fn: Callable that specifies clipping function. Input to clip is a + flat list of vars in a record. + clip_value: Float indicating the value at which to clip the record. + use_efficient: Boolean indicating the usage of the efficient tree + aggregation algorithm based on the paper "Efficient Use of + Differentially Private Binary Trees". + """ + self._clip_fn = clip_fn + self._clip_value = clip_value + self._record_specs = record_specs + if use_efficient: + self._tree_aggregator = tree_aggregation.EfficientTreeAggregator( + noise_generator) + else: + self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) + + def initial_global_state(self): + """Returns initial global state.""" + initial_tree_state = self._tree_aggregator.init_state() + initial_noise = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), + self._record_specs) + return TreeResidualSumQuery.GlobalState( + tree_state=initial_tree_state, + clip_value=tf.constant(self._clip_value, tf.float32), + previous_tree_noise=initial_noise) + + def derive_sample_params(self, global_state): + return global_state.clip_value + + def preprocess_record(self, params, record): + """Returns the clipped record using `clip_fn` and params. + + Args: + params: `clip_value` for the record. + record: The record to be processed. + + Returns: + Structure of clipped tensors. + """ + clip_value = params + record_as_list = tf.nest.flatten(record) + clipped_as_list = self._clip_fn(record_as_list, clip_value) + return tf.nest.pack_sequence_as(record, clipped_as_list) + + def get_noised_result(self, sample_state, global_state): + """Updates tree state, and returns residual of noised cumulative sum. + + Args: + sample_state: Sum of clipped records for this round. + global_state: Global state with current samples cumulative sum and tree + state. + + Returns: + A tuple of (noised_cumulative_sum, new_global_state). + """ + tree_noise, new_tree_state = self._tree_aggregator.get_cumsum_and_update( + global_state.tree_state) + noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c, + sample_state, tree_noise, + global_state.previous_tree_noise) + new_global_state = attr.evolve( + global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state) + return noised_sample, new_global_state + + @classmethod + def build_l2_gaussian_query(cls, + clip_norm, + noise_multiplier, + record_specs, + noise_seed=None, + use_efficient=True): + """Returns `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. + + Args: + clip_norm: Each record will be clipped so that it has L2 norm at most + `clip_norm`. + noise_multiplier: The effective noise multiplier for the sum of records. + Noise standard deviation is `clip_norm*noise_multiplier`. + record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + noise_seed: Integer seed for the Gaussian noise generator. If `None`, a + nondeterministic seed based on system time will be generated. + use_efficient: Boolean indicating the usage of the efficient tree + aggregation algorithm based on the paper "Efficient Use of + Differentially Private Binary Trees". + """ + if clip_norm <= 0: + raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') + + if noise_multiplier < 0: + raise ValueError( + f'`noise_multiplier` must be non-negative, got {noise_multiplier}.') + + gaussian_noise_generator = tree_aggregation.GaussianNoiseGenerator( + noise_std=clip_norm * noise_multiplier, + specs=record_specs, + seed=noise_seed) + + def l2_clip_fn(record_as_list, clip_norm): + clipped_record, _ = tf.clip_by_global_norm(record_as_list, clip_norm) + return clipped_record + + return cls( + clip_fn=l2_clip_fn, + clip_value=clip_norm, + record_specs=record_specs, + noise_generator=gaussian_noise_generator, + use_efficient=use_efficient) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py new file mode 100644 index 0000000..8cf2157 --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -0,0 +1,399 @@ +# Copyright 2021, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for `tree_aggregation_query`.""" + +from absl.testing import parameterized + +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.dp_query import test_utils +from tensorflow_privacy.privacy.dp_query import tree_aggregation +from tensorflow_privacy.privacy.dp_query import tree_aggregation_query + + +STRUCT_RECORD = [ + tf.constant([[2.0, 0.0], [0.0, 1.0]]), + tf.constant([-1.0, 0.0]) +] + +SINGLE_VALUE_RECORDS = [tf.constant(1.), tf.constant(3.), tf.constant(5.)] + +STRUCTURE_SPECS = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + STRUCT_RECORD) +NOISE_STD = 5.0 + +STREAMING_SCALARS = np.array(range(7), dtype=np.single) + + +def _get_noise_generator(specs, stddev=NOISE_STD, seed=1): + return tree_aggregation.GaussianNoiseGenerator( + noise_std=stddev, specs=specs, seed=seed) + + +def _get_noise_fn(specs, stddev=NOISE_STD, seed=1): + random_generator = tf.random.Generator.from_seed(seed) + + def noise_fn(): + shape = tf.nest.map_structure(lambda spec: spec.shape, specs) + return tf.nest.map_structure( + lambda x: random_generator.normal(x, stddev=stddev), shape) + + return noise_fn + + +def _get_no_noise_fn(specs): + shape = tf.nest.map_structure(lambda spec: spec.shape, specs) + def no_noise_fn(): + return tf.nest.map_structure(tf.zeros, shape) + + return no_noise_fn + + +def _get_l2_clip_fn(): + + def l2_clip_fn(record_as_list, clip_value): + clipped_record, _ = tf.clip_by_global_norm(record_as_list, clip_value) + return clipped_record + + return l2_clip_fn + + +def _get_l_infty_clip_fn(): + + def l_infty_clip_fn(record_as_list, clip_value): + def clip(record): + return tf.clip_by_value( + record, clip_value_min=-clip_value, clip_value_max=clip_value) + + clipped_record = tf.nest.map_structure(clip, record_as_list) + return clipped_record + + return l_infty_clip_fn + + +class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_correct_initial_global_state_struct_type(self): + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=10., + noise_generator=_get_no_noise_fn(STRUCTURE_SPECS), + record_specs=STRUCTURE_SPECS) + global_state = query.initial_global_state() + + self.assertIsInstance(global_state.tree_state, tree_aggregation.TreeState) + expected_cum_sum = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), + STRUCTURE_SPECS) + self.assertAllClose(expected_cum_sum, global_state.samples_cumulative_sum) + + def test_correct_initial_global_state_single_value_type(self): + record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + SINGLE_VALUE_RECORDS[0]) + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=10., + noise_generator=_get_no_noise_fn(record_specs), + record_specs=record_specs) + global_state = query.initial_global_state() + + self.assertIsInstance(global_state.tree_state, tree_aggregation.TreeState) + expected_cum_sum = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), + record_specs) + self.assertAllClose(expected_cum_sum, global_state.samples_cumulative_sum) + + @parameterized.named_parameters( + ('not_clip_single_record', SINGLE_VALUE_RECORDS[0], 10.0), + ('clip_single_record', SINGLE_VALUE_RECORDS[1], 1.0)) + def test_l2_clips_single_record(self, record, l2_norm_clip): + record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + SINGLE_VALUE_RECORDS[0]) + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=l2_norm_clip, + noise_generator=_get_no_noise_fn(record_specs), + record_specs=record_specs) + global_state = query.initial_global_state() + record_norm = tf.norm(record) + if record_norm > l2_norm_clip: + expected_clipped_record = tf.nest.map_structure( + lambda t: t * l2_norm_clip / record_norm, record) + else: + expected_clipped_record = record + clipped_record = query.preprocess_record(global_state.clip_value, record) + self.assertAllClose(expected_clipped_record, clipped_record) + + @parameterized.named_parameters( + ('not_clip_structure_record', STRUCT_RECORD, 10.0), + ('clip_structure_record', STRUCT_RECORD, 1.0)) + def test_l2_clips_structure_type_record(self, record, l2_norm_clip): + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=l2_norm_clip, + noise_generator=_get_no_noise_fn(STRUCTURE_SPECS), + record_specs=tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + record)) + global_state = query.initial_global_state() + record_norm = tf.linalg.global_norm(record) + if record_norm > l2_norm_clip: + expected_clipped_record = tf.nest.map_structure( + lambda t: t * l2_norm_clip / record_norm, record) + else: + expected_clipped_record = record + clipped_record = query.preprocess_record(global_state.clip_value, record) + self.assertAllClose(expected_clipped_record, clipped_record) + + @parameterized.named_parameters( + ('not_clip_single_record', SINGLE_VALUE_RECORDS[0], 10.0), + ('clip_single_record', SINGLE_VALUE_RECORDS[1], 1.0)) + def test_l_infty_clips_single_record(self, record, norm_clip): + record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + SINGLE_VALUE_RECORDS[0]) + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l_infty_clip_fn(), + clip_value=norm_clip, + noise_generator=_get_no_noise_fn(record_specs), + record_specs=record_specs) + global_state = query.initial_global_state() + expected_clipped_record = tf.nest.map_structure( + lambda t: tf.clip_by_value(t, -norm_clip, norm_clip), record) + clipped_record = query.preprocess_record(global_state.clip_value, record) + self.assertAllClose(expected_clipped_record, clipped_record) + + @parameterized.named_parameters( + ('not_clip_structure_record', STRUCT_RECORD, 10.0), + ('clip_structure_record', STRUCT_RECORD, 1.0)) + def test_linfty_clips_structure_type_record(self, record, norm_clip): + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l_infty_clip_fn(), + clip_value=norm_clip, + noise_generator=_get_no_noise_fn(STRUCTURE_SPECS), + record_specs=tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + record)) + global_state = query.initial_global_state() + expected_clipped_record = tf.nest.map_structure( + lambda t: tf.clip_by_value(t, -norm_clip, norm_clip), record) + clipped_record = query.preprocess_record(global_state.clip_value, record) + self.assertAllClose(expected_clipped_record, clipped_record) + + def test_noiseless_query_single_value_type_record(self): + record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + SINGLE_VALUE_RECORDS[0]) + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=10., + noise_generator=_get_no_noise_fn(record_specs), + record_specs=record_specs) + query_result, _ = test_utils.run_query(query, SINGLE_VALUE_RECORDS) + expected = tf.constant(9.) + self.assertAllClose(query_result, expected) + + def test_noiseless_query_structure_type_record(self): + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=10., + noise_generator=_get_no_noise_fn(STRUCTURE_SPECS), + record_specs=STRUCTURE_SPECS) + query_result, _ = test_utils.run_query(query, + [STRUCT_RECORD, STRUCT_RECORD]) + expected = tf.nest.map_structure(lambda a, b: a + b, STRUCT_RECORD, + STRUCT_RECORD) + self.assertAllClose(query_result, expected) + + @parameterized.named_parameters( + ('two_records_noise_fn', [2.71828, 3.14159], _get_noise_fn), + ('five_records_noise_fn', np.random.uniform(size=5).tolist(), + _get_noise_fn), + ('two_records_generator', [2.71828, 3.14159], _get_noise_generator), + ('five_records_generator', np.random.uniform(size=5).tolist(), + _get_noise_generator), + ) + def test_noisy_cumsum_and_state_update(self, records, value_generator): + num_trials = 200 + record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), + records[0]) + noised_sums = [] + for i in range(num_trials): + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=10., + noise_generator=value_generator(record_specs, seed=i), + record_specs=record_specs) + query_result, _ = test_utils.run_query(query, records) + noised_sums.append(query_result) + result_stddev = np.std(noised_sums) + self.assertNear(result_stddev, NOISE_STD, 0.7) # value for chi-squared test + + @parameterized.named_parameters( + ('no_clip', STREAMING_SCALARS, 10., np.cumsum(STREAMING_SCALARS)), + ('all_clip', STREAMING_SCALARS, 0.5, STREAMING_SCALARS * 0.5), + # STREAMING_SCALARS is list(range(7)), only the last element is clipped + # for the following test, which makes the expected value for the last sum + # to be `cumsum(5)+5`. + ('partial_clip', STREAMING_SCALARS, 5., + np.append(np.cumsum(STREAMING_SCALARS[:-1]), 20.)), + ) + def test_partial_sum_scalar_no_noise(self, streaming_scalars, clip_norm, + partial_sum): + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=clip_norm, + noise_generator=lambda: 0., + record_specs=tf.TensorSpec([]), + ) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + for scalar, expected_sum in zip(streaming_scalars, partial_sum): + sample_state = query.initial_sample_state(scalar) + sample_state = query.accumulate_record(params, sample_state, scalar) + query_result, global_state = query.get_noised_result( + sample_state, global_state) + self.assertEqual(query_result, expected_sum) + + @parameterized.named_parameters( + ('s0t1step8', 0., 1., [1., 1., 2., 1., 2., 2., 3., 1.]), + ('s1t1step8', 1., 1., [2., 3., 5., 5., 7., 8., 10., 9.]), + ('s1t2step8', 1., 2., [3., 4., 7., 6., 9., 10., 13., 10.]), + ) + def test_partial_sum_scalar_tree_aggregation(self, scalar_value, + tree_node_value, + expected_values): + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=scalar_value + 1., # no clip + noise_generator=lambda: tree_node_value, + record_specs=tf.TensorSpec([]), + use_efficient=False, + ) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + for val in expected_values: + # For each streaming step i , the expected value is roughly + # `scalar_value*i + tree_aggregation(tree_node_value, i)` + sample_state = query.initial_sample_state(scalar_value) + sample_state = query.accumulate_record(params, sample_state, scalar_value) + query_result, global_state = query.get_noised_result( + sample_state, global_state) + self.assertEqual(query_result, val) + + @parameterized.named_parameters( + ('efficient', True, tree_aggregation.EfficientTreeAggregator), + ('normal', False, tree_aggregation.TreeAggregator), + ) + def test_sum_tree_aggregator_instance(self, use_efficient, tree_class): + specs = tf.TensorSpec([]) + query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=1., + noise_generator=_get_noise_fn(specs, 1.), + record_specs=specs, + use_efficient=use_efficient, + ) + self.assertIsInstance(query._tree_aggregator, tree_class) + + @parameterized.named_parameters( + ('r5d10n0s1s16eff', 5, 10, 0., 1, 16, 0.1, True), + ('r3d5n1s1s32eff', 3, 5, 1., 1, 32, 1., True), + ('r10d3n1s2s16eff', 10, 3, 1., 2, 16, 10., True), + ('r10d3n1s2s16', 10, 3, 1., 2, 16, 10., False), + ) + def test_build_l2_gaussian_query(self, records_num, record_dim, + noise_multiplier, seed, total_steps, clip, + use_efficient): + record_specs = tf.TensorSpec(shape=[record_dim]) + query = tree_aggregation_query.TreeCumulativeSumQuery.build_l2_gaussian_query( + clip_norm=clip, + noise_multiplier=noise_multiplier, + record_specs=record_specs, + noise_seed=seed, + use_efficient=use_efficient) + reference_query = tree_aggregation_query.TreeCumulativeSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=clip, + noise_generator=_get_noise_generator(record_specs, + clip * noise_multiplier, seed), + record_specs=record_specs, + use_efficient=use_efficient) + global_state = query.initial_global_state() + reference_global_state = reference_query.initial_global_state() + + for _ in range(total_steps): + records = [ + tf.random.uniform(shape=[record_dim], maxval=records_num) + for _ in range(records_num) + ] + query_result, global_state = test_utils.run_query(query, records, + global_state) + reference_query_result, reference_global_state = test_utils.run_query( + reference_query, records, reference_global_state) + self.assertAllClose(query_result, reference_query_result, rtol=1e-6) + + +class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('r5d10n0s1s16eff', 5, 10, 0., 1, 16, 0.1, True), + ('r3d5n1s1s32eff', 3, 5, 1., 1, 32, 1., True), + ('r10d3n1s2s16eff', 10, 3, 1., 2, 16, 10., True), + ('r10d3n1s2s16', 10, 3, 1., 2, 16, 10., False), + ) + def test_sum(self, records_num, record_dim, noise_multiplier, seed, + total_steps, clip, use_efficient): + record_specs = tf.TensorSpec(shape=[record_dim]) + query = tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query( + clip_norm=clip, + noise_multiplier=noise_multiplier, + record_specs=record_specs, + noise_seed=seed, + use_efficient=use_efficient) + sum_query = tree_aggregation_query.TreeCumulativeSumQuery.build_l2_gaussian_query( + clip_norm=clip, + noise_multiplier=noise_multiplier, + record_specs=record_specs, + noise_seed=seed, + use_efficient=use_efficient) + global_state = query.initial_global_state() + sum_global_state = sum_query.initial_global_state() + + cumsum_result = tf.zeros(shape=[record_dim]) + for _ in range(total_steps): + records = [ + tf.random.uniform(shape=[record_dim], maxval=records_num) + for _ in range(records_num) + ] + query_result, global_state = test_utils.run_query(query, records, + global_state) + sum_query_result, sum_global_state = test_utils.run_query( + sum_query, records, sum_global_state) + cumsum_result += query_result + self.assertAllClose(cumsum_result, sum_query_result, rtol=1e-6) + + @parameterized.named_parameters( + ('efficient', True, tree_aggregation.EfficientTreeAggregator), + ('normal', False, tree_aggregation.TreeAggregator), + ) + def test_sum_tree_aggregator_instance(self, use_efficient, tree_class): + specs = tf.TensorSpec([]) + query = tree_aggregation_query.TreeResidualSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=1., + noise_generator=_get_noise_fn(specs, 1.), + record_specs=specs, + use_efficient=use_efficient, + ) + self.assertIsInstance(query._tree_aggregator, tree_class) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py new file mode 100644 index 0000000..9a8be35 --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py @@ -0,0 +1,369 @@ +# Copyright 2021, Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for `tree_aggregation`.""" +import math +from absl.testing import parameterized + +import tensorflow as tf +from tensorflow_privacy.privacy.dp_query import tree_aggregation + + +class ConstantValueGenerator(tree_aggregation.ValueGenerator): + + def __init__(self, constant_value): + self.constant_value = constant_value + + def initialize(self): + return () + + def next(self, state): + return self.constant_value, state + + +class TreeAggregatorTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('total4_step1', 4, [1, 1, 2, 1], 1), + ('total5_step1', 5, [1, 1, 2, 1, 2], 1), + ('total6_step1', 6, [1, 1, 2, 1, 2, 2], 1), + ('total7_step1', 7, [1, 1, 2, 1, 2, 2, 3], 1), + ('total8_step1', 8, [1, 1, 2, 1, 2, 2, 3, 1], 1), + ('total8_step2', 8, [2, 2, 4, 2, 4, 4, 6, 2], 2), + ('total8_step0d5', 8, [0.5, 0.5, 1, 0.5, 1, 1, 1.5, 0.5], 0.5)) + def test_tree_sum_steps_expected(self, total_steps, expected_values, + node_value): + # Test whether `tree_aggregator` will output `expected_value` in each step + # when `total_steps` of leaf nodes are traversed. The value of each tree + # node is a constant `node_value` for test purpose. Note that `node_value` + # denotes the "noise" without private values in private algorithms. + tree_aggregator = tree_aggregation.TreeAggregator( + value_generator=ConstantValueGenerator(node_value)) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + self.assertEqual(expected_values[leaf_node_idx], val) + + @parameterized.named_parameters( + ('total16_step1', 16, 1, 1), + ('total17_step1', 17, 2, 1), + ('total18_step1', 18, 2, 1), + ('total19_step1', 19, 3, 1), + ('total20_step0d5', 20, 1, 0.5), + ('total21_step2', 21, 6, 2), + ('total1024_step1', 1024, 1, 1), + ('total1025_step1', 1025, 2, 1), + ('total1026_step1', 1026, 2, 1), + ('total1027_step1', 1027, 3, 1), + ('total1028_step0d5', 1028, 1, 0.5), + ('total1029_step2', 1029, 6, 2), + ) + def test_tree_sum_last_step_expected(self, total_steps, expected_value, + node_value): + # Test whether `tree_aggregator` will output `expected_value` after + # `total_steps` of leaf nodes are traversed. The value of each tree node + # is a constant `node_value` for test purpose. Note that `node_value` + # denotes the "noise" without private values in private algorithms. + tree_aggregator = tree_aggregation.TreeAggregator( + value_generator=ConstantValueGenerator(node_value)) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + self.assertEqual(expected_value, val) + + @parameterized.named_parameters( + ('total16_step1', 16, 1, 1), + ('total17_step1', 17, 2, 1), + ('total18_step1', 18, 2, 1), + ('total19_step1', 19, 3, 1), + ('total20_step0d5', 20, 1, 0.5), + ('total21_step2', 21, 6, 2), + ('total1024_step1', 1024, 1, 1), + ('total1025_step1', 1025, 2, 1), + ('total1026_step1', 1026, 2, 1), + ('total1027_step1', 1027, 3, 1), + ('total1028_step0d5', 1028, 1, 0.5), + ('total1029_step2', 1029, 6, 2), + ) + def test_tree_sum_last_step_expected_value_fn(self, total_steps, + expected_value, node_value): + # Test no-arg function as stateless value generator. + tree_aggregator = tree_aggregation.TreeAggregator( + value_generator=lambda: node_value) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + self.assertEqual(expected_value, val) + + @parameterized.named_parameters( + ('total8_step1', 8, 1), + ('total8_step2', 8, 2), + ('total8_step0d5', 8, 0.5), + ('total32_step0d5', 32, 0.5), + ('total1024_step0d5', 1024, 0.5), + ('total2020_step0d5', 2020, 0.5), + ) + def test_tree_sum_steps_max(self, total_steps, node_value): + tree_aggregator = tree_aggregation.TreeAggregator( + value_generator=ConstantValueGenerator(node_value)) + max_val = node_value * math.ceil(math.log2(total_steps)) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + self.assertLessEqual(val, max_val) + + @parameterized.named_parameters( + ('total4_std1_d1000', 4, [1, 1, 2, 1], 1, [1000], 0.15), + ('total4_std1_d10000', 4, [1, 1, 2, 1], 1, [10000], 0.05), + ('total7_std1_d1000', 7, [1, 1, 2, 1, 2, 2, 3], 1, [1000], 0.15), + ('total8_std1_d1000', 8, [1, 1, 2, 1, 2, 2, 3, 1], 1, [1000], 0.15), + ('total8_std2_d1000', 8, [4, 4, 8, 4, 8, 8, 12, 4], 2, [1000], 0.15), + ('total8_std0d5_d1000', 8, [0.25, 0.25, 0.5, 0.25, 0.5, 0.5, 0.75, 0.25 + ], 0.5, [1000], 0.15)) + def test_tree_sum_noise_expected(self, total_steps, expected_variance, + noise_std, variable_shape, tolerance): + # Test whether `tree_aggregator` will output `expected_variance` (within a + # relative `tolerance`) in each step when `total_steps` of leaf nodes are + # traversed. Each tree node is a `variable_shape` tensor of Gaussian noise + # with `noise_std`. + random_generator = tree_aggregation.GaussianNoiseGenerator( + noise_std, tf.TensorSpec(variable_shape), seed=2020) + tree_aggregator = tree_aggregation.TreeAggregator( + value_generator=random_generator) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + self.assertAllClose( + math.sqrt(expected_variance[leaf_node_idx]), + tf.math.reduce_std(val), + rtol=tolerance) + + def test_cumsum_vector(self, total_steps=15): + + tree_aggregator = tree_aggregation.TreeAggregator( + value_generator=ConstantValueGenerator([ + tf.ones([2, 2], dtype=tf.float32), + tf.constant([2], dtype=tf.float32) + ])) + tree_aggregator_truth = tree_aggregation.TreeAggregator( + value_generator=ConstantValueGenerator(1.)) + state = tree_aggregator.init_state() + truth_state = tree_aggregator_truth.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + expected_val, truth_state = tree_aggregator_truth.get_cumsum_and_update( + truth_state) + self.assertEqual( + tree_aggregation.get_step_idx(state), + tree_aggregation.get_step_idx(truth_state)) + expected_result = [ + expected_val * tf.ones([2, 2], dtype=tf.float32), + expected_val * tf.constant([2], dtype=tf.float32), + ] + tf.nest.map_structure(self.assertAllEqual, val, expected_result) + + +class EfficientTreeAggregatorTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('total1_step1', 1, 1, 1.), + ('total2_step1', 2, 4. / 3., 1.), + ('total3_step1', 3, 4. / 3. + 1., 1.), + ('total4_step1', 4, 12. / 7., 1.), + ('total5_step1', 5, 12. / 7. + 1., 1.), + ('total6_step1', 6, 12. / 7. + 4. / 3., 1.), + ('total7_step1', 7, 12. / 7. + 4. / 3. + 1., 1.), + ('total8_step1', 8, 32. / 15., 1.), + ('total1024_step1', 1024, 11. / (2 - .5**10), 1.), + ('total1025_step1', 1025, 11. / (2 - .5**10) + 1., 1.), + ('total1026_step1', 1026, 11. / (2 - .5**10) + 4. / 3., 1.), + ('total1027_step1', 1027, 11. / (2 - .5**10) + 4. / 3. + 1.0, 1.), + ('total1028_step0d5', 1028, (11. / (2 - .5**10) + 12. / 7.) * .5, .5), + ('total1029_step2', 1029, (11. / (2 - .5**10) + 12. / 7. + 1.) * 2., 2.), + ) + def test_tree_sum_last_step_expected(self, total_steps, expected_value, + step_value): + # Test whether `tree_aggregator` will output `expected_value` after + # `total_steps` of leaf nodes are traversed. The value of each tree node + # is a constant `node_value` for test purpose. Note that `node_value` + # denotes the "noise" without private values in private algorithms. The + # `expected_value` is based on a weighting schema strongly depends on the + # depth of the binary tree. + tree_aggregator = tree_aggregation.EfficientTreeAggregator( + value_generator=ConstantValueGenerator(step_value)) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + self.assertAllClose(expected_value, val) + + @parameterized.named_parameters( + ('total4_std1_d1000', 4, 4. / 7., 1., [1000], 0.15), + ('total4_std1_d10000', 4, 4. / 7., 1., [10000], 0.05), + ('total7_std1_d1000', 7, 4. / 7. + 2. / 3. + 1., 1, [1000], 0.15), + ('total8_std1_d1000', 8, 8. / 15., 1., [1000], 0.15), + ('total8_std2_d1000', 8, 8. / 15. * 4, 2., [1000], 0.15), + ('total8_std0d5_d1000', 8, 8. / 15. * .25, .5, [1000], 0.15)) + def test_tree_sum_noise_expected(self, total_steps, expected_variance, + noise_std, variable_shape, tolerance): + # Test whether `tree_aggregator` will output `expected_variance` (within a + # relative `tolerance`) after `total_steps` of leaf nodes are traversed. + # Each tree node is a `variable_shape` tensor of Gaussian noise with + # `noise_std`. Note that the variance of a tree node is smaller than + # the given vanilla node `noise_std` because of the update rule of + # `EfficientTreeAggregator`. + random_generator = tree_aggregation.GaussianNoiseGenerator( + noise_std, tf.TensorSpec(variable_shape), seed=2020) + tree_aggregator = tree_aggregation.EfficientTreeAggregator( + value_generator=random_generator) + state = tree_aggregator.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + self.assertAllClose( + math.sqrt(expected_variance), tf.math.reduce_std(val), rtol=tolerance) + + @parameterized.named_parameters( + ('total4_std1_d1000', 4, 1., [1000], 1e-6), + ('total30_std2_d1000', 30, 2, [1000], 1e-6), + ('total32_std0d5_d1000', 32, .5, [1000], 1e-6), + ('total60_std1_d1000', 60, 1, [1000], 1e-6), + ) + def test_tree_sum_noise_efficient(self, total_steps, noise_std, + variable_shape, tolerance): + # Test the variance returned by `EfficientTreeAggregator` is smaller than + # `TreeAggregator` (within a relative `tolerance`) after `total_steps` of + # leaf nodes are traversed. Each tree node is a `variable_shape` tensor of + # Gaussian noise with `noise_std`. A small `tolerance` is used for numerical + # stability, `tolerance==0` means `EfficientTreeAggregator` is strictly + # better than `TreeAggregator` for reducing variance. + random_generator = tree_aggregation.GaussianNoiseGenerator( + noise_std, tf.TensorSpec(variable_shape)) + tree_aggregator = tree_aggregation.EfficientTreeAggregator( + value_generator=random_generator) + tree_aggregator_baseline = tree_aggregation.TreeAggregator( + value_generator=random_generator) + + state = tree_aggregator.init_state() + state_baseline = tree_aggregator_baseline.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + val_baseline, state_baseline = tree_aggregator_baseline.get_cumsum_and_update( + state_baseline) + self.assertLess( + tf.math.reduce_variance(val), + (1 + tolerance) * tf.math.reduce_variance(val_baseline)) + + def test_cumsum_vector(self, total_steps=15): + + tree_aggregator = tree_aggregation.EfficientTreeAggregator( + value_generator=ConstantValueGenerator([ + tf.ones([2, 2], dtype=tf.float32), + tf.constant([2], dtype=tf.float32) + ])) + tree_aggregator_truth = tree_aggregation.EfficientTreeAggregator( + value_generator=ConstantValueGenerator(1.)) + state = tree_aggregator.init_state() + truth_state = tree_aggregator_truth.init_state() + for leaf_node_idx in range(total_steps): + self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state)) + val, state = tree_aggregator.get_cumsum_and_update(state) + expected_val, truth_state = tree_aggregator_truth.get_cumsum_and_update( + truth_state) + self.assertEqual( + tree_aggregation.get_step_idx(state), + tree_aggregation.get_step_idx(truth_state)) + expected_result = [ + expected_val * tf.ones([2, 2], dtype=tf.float32), + expected_val * tf.constant([2], dtype=tf.float32), + ] + tf.nest.map_structure(self.assertAllClose, val, expected_result) + + +class GaussianNoiseGeneratorTest(tf.test.TestCase): + + def test_random_generator_tf(self, + noise_mean=1.0, + noise_std=1.0, + samples=1000, + tolerance=0.15): + g = tree_aggregation.GaussianNoiseGenerator( + noise_std, specs=tf.TensorSpec([]), seed=2020) + gstate = g.initialize() + + @tf.function + def return_noise(state): + value, state = g.next(state) + return noise_mean + value, state + + noise_values = [] + for _ in range(samples): + value, gstate = return_noise(gstate) + noise_values.append(value) + noise_values = tf.stack(noise_values) + self.assertAllClose( + [tf.math.reduce_mean(noise_values), + tf.math.reduce_std(noise_values)], [noise_mean, noise_std], + rtol=tolerance) + + def test_seed_state(self, seed=1, steps=32, noise_std=0.1): + g = tree_aggregation.GaussianNoiseGenerator( + noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed) + gstate = g.initialize() + g2 = tree_aggregation.GaussianNoiseGenerator( + noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed) + gstate2 = g.initialize() + self.assertAllEqual(gstate, gstate2) + for _ in range(steps): + value, gstate = g.next(gstate) + value2, gstate2 = g2.next(gstate2) + self.assertAllEqual(value, value2) + self.assertAllEqual(gstate, gstate2) + + def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1): + g = tree_aggregation.GaussianNoiseGenerator( + noise_std=noise_std, specs=tf.TensorSpec([])) + gstate = g.initialize() + g2 = tree_aggregation.GaussianNoiseGenerator( + noise_std=noise_std, specs=tf.TensorSpec([])) + gstate2 = g2.initialize() + for _ in range(steps): + value, gstate = g.next(gstate) + value2, gstate2 = g2.next(gstate2) + self.assertNotAllEqual(value, value2) + self.assertNotAllEqual(gstate, gstate2) + + def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1): + specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])] + g = tree_aggregation.GaussianNoiseGenerator( + noise_std=noise_std, specs=specs, seed=seed) + gstate = g.initialize() + g2 = tree_aggregation.GaussianNoiseGenerator( + noise_std=noise_std, specs=specs, seed=seed) + gstate2 = g2.initialize() + for _ in range(steps): + value, gstate = g.next(gstate) + value2, gstate2 = g2.next(gstate2) + self.assertAllClose(value, value2) + self.assertAllEqual(gstate, gstate2) + + +if __name__ == '__main__': + tf.test.main()