From 4d335d1b69206712b6325626d7df1063b9815ade Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 27 Jul 2021 17:42:18 -0700 Subject: [PATCH] (1) Merge `CentralTreeSumQuery` and `DistributedTreeSumQuery` into one DPQuery to modularize things. The new query takes in an `inner_query` argument. Depending on the behavior of inner query, the query will follow central DP or distributed DP. (2) Remove the hard-coded L1 clipping and replace with norm bound checking in the inner query. This design allows us to use whatever clipping factory we want outside the DPQuery. PiperOrigin-RevId: 387236482 --- .../dp_query/tree_aggregation_query.py | 286 ++++++---------- .../dp_query/tree_aggregation_query_test.py | 308 ++++++------------ 2 files changed, 204 insertions(+), 390 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 5717e4f..bd6ff3c 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -15,21 +15,18 @@ `TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual online observation queries relying on `tree_aggregation`. 'Online' means that -the leaf nodes of the tree arrive one by one as the time proceeds. The leaves -are vector records as defined in `dp_query.DPQuery`. +the leaf nodes of the tree arrive one by one as the time proceeds. -`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for -central/distributed offline tree aggregation protocol. 'Offline' means all the -leaf nodes are ready before the protocol starts. Each record, different from -what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes). +`TreeRangeSumQuery` is a `DPQuery`s for offline tree aggregation protocol. +'Offline' means all the leaf nodes are ready before the protocol starts. """ -import distutils import math -from typing import Optional import attr import tensorflow as tf +from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query from tensorflow_privacy.privacy.dp_query import dp_query +from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import tree_aggregation @@ -442,217 +439,84 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: return tree -def _get_add_noise(stddev, seed: int = None): - """Utility function to decide which `add_noise` to use according to tf version.""" - if distutils.version.LooseVersion( - tf.__version__) < distutils.version.LooseVersion('2.0.0'): +class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): + """Implements dp_query for accurate range queries using tree aggregation. - # The seed should be only used for testing purpose. - if seed is not None: - tf.random.set_seed(seed) - - def add_noise(v): - return v + tf.random.normal( - tf.shape(input=v), stddev=stddev, dtype=v.dtype) - else: - random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed) - - def add_noise(v): - return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) - - return add_noise - - -class CentralTreeSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for differentially private tree aggregation protocol. - - Implements a central variant of the tree aggregation protocol from the paper - "'Is interaction necessary for distributed private learning?.' Adam Smith, - Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with - gaussian mechanism. The first step is to clip the clients' local updates (i.e. - a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it - does not exceed a prespecified upper bound. The second step is to construct - the tree on the clipped update. The third step is to add independent gaussian - noise to each node in the tree. The returned tree can support efficient and - accurate range queries with differential privacy. + Implements a variant of the tree aggregation protocol from. "Is interaction + necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta, + Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to + the tree for differential privacy. Any range query can be decomposed into the + sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram. + Improves efficiency and reduces noise scale. """ @attr.s(frozen=True) class GlobalState(object): - """Class defining global state for `CentralTreeSumQuery`. + """Class defining global state for TreeRangeSumQuery. Attributes: - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. - """ - l1_bound = attr.ib() - - def __init__(self, - stddev: float, - arity: int = 2, - l1_bound: int = 10, - seed: Optional[int] = None): - """Initializes the `CentralTreeSumQuery`. - - Args: - stddev: The stddev of the noise added to each internal node of the - constructed tree. - arity: The branching factor of the tree. - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. - seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for - test purpose. - """ - self._stddev = stddev - self._arity = arity - self._l1_bound = l1_bound - self._seed = seed - - def initial_global_state(self): - """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" - return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound) - - def derive_sample_params(self, global_state): - """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" - return global_state.l1_bound - - def preprocess_record(self, params, record): - """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" - casted_record = tf.cast(record, tf.float32) - l1_norm = tf.norm(casted_record, ord=1) - - l1_bound = tf.cast(params, tf.float32) - - preprocessed_record, _ = tf.clip_by_global_norm([casted_record], - l1_bound, - use_norm=l1_norm) - - return preprocessed_record[0] - - def get_noised_result(self, sample_state, global_state): - """Implements `tensorflow_privacy.DPQuery.get_noised_result`. - - Args: - sample_state: a frequency histogram. - global_state: hyper-parameters of the query. - - Returns: - a `tf.RaggedTensor` representing the tree built on top of `sample_state`. - The jth node on the ith layer of the tree can be accessed by tree[i][j] - where tree is the returned value. - """ - add_noise = _get_add_noise(self._stddev, self._seed) - tree = _build_tree_from_leaf(sample_state, self._arity) - return tf.map_fn(add_noise, tree), global_state - - -class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for differentially private tree aggregation protocol. - - The difference from `CentralTreeSumQuery` is that the tree construction and - gaussian noise addition happen in `preprocess_records`. The difference only - takes effect when used together with - `tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class - should be treated as equal with `CentralTreeSumQuery`. - - Implements a distributed version of the tree aggregation protocol from. "Is - interaction necessary for distributed private learning?." by replacing their - local randomizer with gaussian mechanism. The first step is to check the L1 - norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes - of the tree) to make sure it does not exceed a prespecified upper bound. The - second step is to construct the tree. The third step is to add independent - gaussian noise to each node in the tree. The returned tree can support - efficient and accurate range queries with differential privacy. - """ - - @attr.s(frozen=True) - class GlobalState(object): - """Class defining global state for DistributedTreeSumQuery. - - Attributes: - stddev: The stddev of the noise added to each internal node in the - constructed tree. arity: The branching factor of the tree (i.e. the number of children each internal node has). - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. + inner_query_state: The global state of the inner query. """ - stddev = attr.ib() arity = attr.ib() - l1_bound = attr.ib() + inner_query_state = attr.ib() def __init__(self, - stddev: float, - arity: int = 2, - l1_bound: int = 10, - seed: Optional[int] = None): - """Initializes the `DistributedTreeSumQuery`. + inner_query: dp_query.SumAggregationDPQuery, + arity: int = 2): + """Initializes the `TreeRangeSumQuery`. Args: - stddev: The stddev of the noise added to each node in the tree. - arity: The branching factor of the tree. - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. - seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for - test purpose. + inner_query: The inner `DPQuery` that adds noise to the tree. + arity: The branching factor of the tree (i.e. the number of children each + internal node has). Defaults to 2. """ - self._stddev = stddev + self._inner_query = inner_query self._arity = arity - self._l1_bound = l1_bound - self._seed = seed + + if self._arity < 1: + raise ValueError(f'Invalid arity={arity} smaller than 2.') def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" - return DistributedTreeSumQuery.GlobalState( - stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound) + return TreeRangeSumQuery.GlobalState( + arity=self._arity, + inner_query_state=self._inner_query.initial_global_state()) def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" - return (global_state.stddev, global_state.arity, global_state.l1_bound) + return (global_state.arity, + self._inner_query.derive_sample_params( + global_state.inner_query_state)) def preprocess_record(self, params, record): """Implements `tensorflow_privacy.DPQuery.preprocess_record`. - This method clips the input record by L1 norm, constructs a tree on top of - it, and adds gaussian noise to each node of the tree for differential - privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function - flattens the `tf.RaggedTensor` before outputting it. This is useful when - used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does - not accept ragged output tensor. + This method builds the tree, flattens it and applies + `inner_query.preprocess_record` to the flattened tree. Args: - params: hyper-parameters for preprocessing record, (stddev, aritry, - l1_bound) - record: leaf nodes for the tree. + params: Hyper-parameters for preprocessing record. + record: A histogram representing the leaf nodes of the tree. Returns: - `tf.Tensor` representing the flattened version of the tree. + A `tf.Tensor` representing the flattened version of the preprocessed tree. """ - _, arity, l1_bound_ = params - l1_bound = tf.cast(l1_bound_, tf.float32) - - casted_record = tf.cast(record, tf.float32) - l1_norm = tf.norm(casted_record, ord=1) - - preprocessed_record, _ = tf.clip_by_global_norm([casted_record], - l1_bound, - use_norm=l1_norm) - preprocessed_record = preprocessed_record[0] - - add_noise = _get_add_noise(self._stddev, self._seed) - tree = _build_tree_from_leaf(preprocessed_record, arity) - noisy_tree = tf.map_fn(add_noise, tree) + arity, inner_query_params = params + preprocessed_record = _build_tree_from_leaf(record, arity).flat_values + preprocessed_record = self._inner_query.preprocess_record( + inner_query_params, preprocessed_record) # The following codes reshape the output vector so the output shape of can # be statically inferred. This is useful when used with # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know # the output shape of this function statically and explicitly. - flat_noisy_tree = noisy_tree.flat_values - flat_tree_shape = [ + preprocessed_record_shape = [ (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - 1) // (self._arity - 1) ] - return tf.reshape(flat_noisy_tree, flat_tree_shape) + return tf.reshape(preprocessed_record, preprocessed_record_shape) def get_noised_result(self, sample_state, global_state): """Implements `tensorflow_privacy.DPQuery.get_noised_result`. @@ -661,12 +525,11 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery): output by `preprocess_records.` Args: - sample_state: `tf.Tensor` for the flattened tree. - global_state: hyper-parameters including noise multiplier, the branching - factor of the tree and the maximum records per user. + sample_state: A `tf.Tensor` for the flattened tree. + global_state: The global state of the protocol. Returns: - a `tf.RaggedTensor` for the tree. + A `tf.RaggedTensor` representing the tree. """ # The [0] is needed because of how tf.RaggedTensor.from_two_splits works. # print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6], @@ -682,3 +545,60 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery): tree = tf.RaggedTensor.from_row_splits( values=sample_state, row_splits=row_splits) return tree, global_state + + @classmethod + def build_central_gaussian_query(cls, + l2_norm_clip: float, + stddev: float, + arity: int = 2): + """Returns `TreeRangeSumQuery` with central Gaussian noise. + + Args: + l2_norm_clip: Each record should be clipped so that it has L2 norm at most + `l2_norm_clip`. + stddev: Stddev of the central Gaussian noise. + arity: The branching factor of the tree (i.e. the number of children each + internal node has). Defaults to 2. + """ + if l2_norm_clip <= 0: + raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.') + + if stddev < 0: + raise ValueError(f'`stddev` must be non-negative, got {stddev}.') + + if arity < 2: + raise ValueError(f'`arity` must be at least 2, got {arity}.') + + inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev) + + return cls(arity=arity, inner_query=inner_query) + + @classmethod + def build_distributed_discrete_gaussian_query(cls, + l2_norm_bound: float, + local_stddev: float, + arity: int = 2): + """Returns `TreeRangeSumQuery` with central Gaussian noise. + + Args: + l2_norm_bound: Each record should be clipped so that it has L2 norm at + most `l2_norm_bound`. + local_stddev: Scale/stddev of the local discrete Gaussian noise. + arity: The branching factor of the tree (i.e. the number of children each + internal node has). Defaults to 2. + """ + if l2_norm_bound <= 0: + raise ValueError( + f'`l2_clip_bound` must be positive, got {l2_norm_bound}.') + + if local_stddev < 0: + raise ValueError( + f'`local_stddev` must be non-negative, got {local_stddev}.') + + if arity < 2: + raise ValueError(f'`arity` must be at least 2, got {arity}.') + + inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery( + l2_norm_bound, local_stddev) + + return cls(arity=arity, inner_query=inner_query) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index cc3a89a..3713b5d 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -423,111 +423,115 @@ class BuildTreeTest(tf.test.TestCase, parameterized.TestCase): self.assertEqual(tree[layer][idx], expected_value) -class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): +class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase): - def test_initial_global_state_type(self): - - query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - self.assertIsInstance( - global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState) - - def test_derive_sample_params(self): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - self.assertAllClose(params, 10.) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], - dtype=tf.float32)), + @parameterized.product( + inner_query=['central', 'distributed'], + params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)], ) - def test_preprocess_record(self, arity, record): - query = tree_aggregation_query.CentralTreeSumQuery( - stddev=NOISE_STD, arity=arity) + def test_raises_error(self, inner_query, params): + clip_norm, stddev, arity = params + with self.assertRaises(ValueError): + if inner_query == 'central': + tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( + clip_norm, stddev, arity) + elif inner_query == 'distributed': + tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( + clip_norm, stddev, arity) + + @parameterized.product( + inner_query=['central', 'distributed'], + clip_norm=[0.1, 1.0, 10.0], + stddev=[0.1, 1.0, 10.0]) + def test_initial_global_state_type(self, inner_query, clip_norm, stddev): + + if inner_query == 'central': + query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( + clip_norm, stddev) + elif inner_query == 'distributed': + query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( + clip_norm, stddev) + global_state = query.initial_global_state() + self.assertIsInstance(global_state, + tree_aggregation_query.TreeRangeSumQuery.GlobalState) + + @parameterized.product( + inner_query=['central', 'distributed'], + clip_norm=[0.1, 1.0, 10.0], + stddev=[0.1, 1.0, 10.0], + arity=[2, 3, 4]) + def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity): + if inner_query == 'central': + query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( + clip_norm, stddev, arity) + elif inner_query == 'distributed': + query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( + clip_norm, stddev, arity) + global_state = query.initial_global_state() + derived_arity, inner_query_state = query.derive_sample_params(global_state) + self.assertAllClose(derived_arity, arity) + if inner_query == 'central': + self.assertAllClose(inner_query_state, clip_norm) + elif inner_query == 'distributed': + self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm) + self.assertAllClose(inner_query_state.local_stddev, stddev) + + @parameterized.product( + (dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]), + dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])), + inner_query=['central', 'distributed'], + ) + def test_preprocess_record(self, inner_query, arity, expected_tree): + if inner_query == 'central': + query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( + 10., 0., arity) + record = tf.constant([1, 0, 0, 0], dtype=tf.float32) + expected_tree = tf.cast(expected_tree, tf.float32) + elif inner_query == 'distributed': + query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( + 10., 0., arity) + record = tf.constant([1, 0, 0, 0], dtype=tf.int32) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) preprocessed_record = query.preprocess_record(params, record) - - self.assertAllClose(preprocessed_record, record) + self.assertAllClose(preprocessed_record, expected_tree) @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.constant([5, 5, 0, 0], dtype=tf.int32)), - ('binary_test_float', 2, tf.constant( - [10., 10., 0., 0.], - dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.constant([5, 5, 0, 0], dtype=tf.int32)), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.constant([5., 5., 0., 0.], dtype=tf.float32)), + ('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]), + ('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]), ) - def test_preprocess_record_clipped(self, arity, record, - expected_clipped_value): - query = tree_aggregation_query.CentralTreeSumQuery( - stddev=NOISE_STD, arity=arity) + def test_distributed_preprocess_record_with_noise(self, local_stddev, record, + expected_tree): + query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( + 10., local_stddev) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_clipped_value) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ) - def test_get_noised_result(self, arity, record, expected_tree): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - sample_state, global_state = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose(sample_state, expected_tree) - - @parameterized.named_parameters( - ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ) - def test_get_noised_result_with_noise(self, stddev, record, expected_tree): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - - sample_state, _ = query.get_noised_result(preprocessed_record, global_state) self.assertAllClose( - sample_state.flat_values, expected_tree, atol=3 * stddev) + preprocessed_record, expected_tree, atol=10 * local_stddev) - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + @parameterized.product( + (dict( + arity=2, + expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])), + dict( + arity=3, + expected_tree=tf.ragged.constant([[1], [1, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0]]))), + inner_query=['central', 'distributed'], ) - def test_get_noised_result_clipped(self, arity, record, expected_tree): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) + def test_get_noised_result(self, inner_query, arity, expected_tree): + if inner_query == 'central': + query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( + 10., 0., arity) + record = tf.constant([1, 0, 0, 0], dtype=tf.float32) + expected_tree = tf.cast(expected_tree, tf.float32) + elif inner_query == 'distributed': + query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( + 10., 0., arity) + record = tf.constant([1, 0, 0, 0], dtype=tf.int32) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) preprocessed_record = query.preprocess_record(params, record) @@ -536,128 +540,18 @@ class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(sample_state, expected_tree) - -class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): - - def test_initial_global_state_type(self): - - query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - self.assertIsInstance( - global_state, - tree_aggregation_query.DistributedTreeSumQuery.GlobalState) - - def test_derive_sample_params(self): - query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - stddev, arity, l1_bound = query.derive_sample_params(global_state) - self.assertAllClose(stddev, NOISE_STD) - self.assertAllClose(arity, 2) - self.assertAllClose(l1_bound, 10) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. - ])), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. - ])), - ) - def test_preprocess_record(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) + @parameterized.product(stddev=[0.1, 1.0, 10.0]) + def test_central_get_noised_result_with_noise(self, stddev): + query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( + 10., stddev) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_tree) - - @parameterized.named_parameters( - ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ) - def test_preprocess_record_with_noise(self, stddev, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=stddev, seed=0) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - - preprocessed_record = query.preprocess_record(params, record) - - self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), - ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant( - [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant( - [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), - ) - def test_preprocess_record_clipped(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_tree) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ) - def test_get_noised_result(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) + preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.])) sample_state, global_state = query.get_noised_result( preprocessed_record, global_state) - self.assertAllClose(sample_state, expected_tree) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), - ) - def test_get_noised_result_clipped(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - sample_state, global_state = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose(sample_state, expected_tree) + self.assertAllClose( + sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev) if __name__ == '__main__':