From eef5810d94d56428717757574ed39bb0430cf246 Mon Sep 17 00:00:00 2001 From: Keith Rush Date: Tue, 27 Jul 2021 20:03:58 -0700 Subject: [PATCH] Automated rollback of commit 4d335d1b69206712b6325626d7df1063b9815ade PiperOrigin-RevId: 387254617 --- .../dp_query/tree_aggregation_query.py | 292 +++++++++++------ .../dp_query/tree_aggregation_query_test.py | 308 ++++++++++++------ 2 files changed, 393 insertions(+), 207 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index bd6ff3c..5717e4f 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -15,18 +15,21 @@ `TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual online observation queries relying on `tree_aggregation`. 'Online' means that -the leaf nodes of the tree arrive one by one as the time proceeds. +the leaf nodes of the tree arrive one by one as the time proceeds. The leaves +are vector records as defined in `dp_query.DPQuery`. -`TreeRangeSumQuery` is a `DPQuery`s for offline tree aggregation protocol. -'Offline' means all the leaf nodes are ready before the protocol starts. +`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for +central/distributed offline tree aggregation protocol. 'Offline' means all the +leaf nodes are ready before the protocol starts. Each record, different from +what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes). """ +import distutils import math +from typing import Optional import attr import tensorflow as tf -from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query from tensorflow_privacy.privacy.dp_query import dp_query -from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import tree_aggregation @@ -439,84 +442,217 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: return tree -class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for accurate range queries using tree aggregation. +def _get_add_noise(stddev, seed: int = None): + """Utility function to decide which `add_noise` to use according to tf version.""" + if distutils.version.LooseVersion( + tf.__version__) < distutils.version.LooseVersion('2.0.0'): - Implements a variant of the tree aggregation protocol from. "Is interaction - necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta, - Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to - the tree for differential privacy. Any range query can be decomposed into the - sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram. - Improves efficiency and reduces noise scale. + # The seed should be only used for testing purpose. + if seed is not None: + tf.random.set_seed(seed) + + def add_noise(v): + return v + tf.random.normal( + tf.shape(input=v), stddev=stddev, dtype=v.dtype) + else: + random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed) + + def add_noise(v): + return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) + + return add_noise + + +class CentralTreeSumQuery(dp_query.SumAggregationDPQuery): + """Implements dp_query for differentially private tree aggregation protocol. + + Implements a central variant of the tree aggregation protocol from the paper + "'Is interaction necessary for distributed private learning?.' Adam Smith, + Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with + gaussian mechanism. The first step is to clip the clients' local updates (i.e. + a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it + does not exceed a prespecified upper bound. The second step is to construct + the tree on the clipped update. The third step is to add independent gaussian + noise to each node in the tree. The returned tree can support efficient and + accurate range queries with differential privacy. """ @attr.s(frozen=True) class GlobalState(object): - """Class defining global state for TreeRangeSumQuery. + """Class defining global state for `CentralTreeSumQuery`. Attributes: - arity: The branching factor of the tree (i.e. the number of children each - internal node has). - inner_query_state: The global state of the inner query. + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. """ - arity = attr.ib() - inner_query_state = attr.ib() + l1_bound = attr.ib() def __init__(self, - inner_query: dp_query.SumAggregationDPQuery, - arity: int = 2): - """Initializes the `TreeRangeSumQuery`. + stddev: float, + arity: int = 2, + l1_bound: int = 10, + seed: Optional[int] = None): + """Initializes the `CentralTreeSumQuery`. Args: - inner_query: The inner `DPQuery` that adds noise to the tree. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). Defaults to 2. + stddev: The stddev of the noise added to each internal node of the + constructed tree. + arity: The branching factor of the tree. + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. + seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for + test purpose. """ - self._inner_query = inner_query + self._stddev = stddev self._arity = arity - - if self._arity < 1: - raise ValueError(f'Invalid arity={arity} smaller than 2.') + self._l1_bound = l1_bound + self._seed = seed def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" - return TreeRangeSumQuery.GlobalState( - arity=self._arity, - inner_query_state=self._inner_query.initial_global_state()) + return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound) def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" - return (global_state.arity, - self._inner_query.derive_sample_params( - global_state.inner_query_state)) + return global_state.l1_bound + + def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" + casted_record = tf.cast(record, tf.float32) + l1_norm = tf.norm(casted_record, ord=1) + + l1_bound = tf.cast(params, tf.float32) + + preprocessed_record, _ = tf.clip_by_global_norm([casted_record], + l1_bound, + use_norm=l1_norm) + + return preprocessed_record[0] + + def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`. + + Args: + sample_state: a frequency histogram. + global_state: hyper-parameters of the query. + + Returns: + a `tf.RaggedTensor` representing the tree built on top of `sample_state`. + The jth node on the ith layer of the tree can be accessed by tree[i][j] + where tree is the returned value. + """ + add_noise = _get_add_noise(self._stddev, self._seed) + tree = _build_tree_from_leaf(sample_state, self._arity) + return tf.map_fn(add_noise, tree), global_state + + +class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery): + """Implements dp_query for differentially private tree aggregation protocol. + + The difference from `CentralTreeSumQuery` is that the tree construction and + gaussian noise addition happen in `preprocess_records`. The difference only + takes effect when used together with + `tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class + should be treated as equal with `CentralTreeSumQuery`. + + Implements a distributed version of the tree aggregation protocol from. "Is + interaction necessary for distributed private learning?." by replacing their + local randomizer with gaussian mechanism. The first step is to check the L1 + norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes + of the tree) to make sure it does not exceed a prespecified upper bound. The + second step is to construct the tree. The third step is to add independent + gaussian noise to each node in the tree. The returned tree can support + efficient and accurate range queries with differential privacy. + """ + + @attr.s(frozen=True) + class GlobalState(object): + """Class defining global state for DistributedTreeSumQuery. + + Attributes: + stddev: The stddev of the noise added to each internal node in the + constructed tree. + arity: The branching factor of the tree (i.e. the number of children each + internal node has). + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. + """ + stddev = attr.ib() + arity = attr.ib() + l1_bound = attr.ib() + + def __init__(self, + stddev: float, + arity: int = 2, + l1_bound: int = 10, + seed: Optional[int] = None): + """Initializes the `DistributedTreeSumQuery`. + + Args: + stddev: The stddev of the noise added to each node in the tree. + arity: The branching factor of the tree. + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. + seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for + test purpose. + """ + self._stddev = stddev + self._arity = arity + self._l1_bound = l1_bound + self._seed = seed + + def initial_global_state(self): + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" + return DistributedTreeSumQuery.GlobalState( + stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound) + + def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" + return (global_state.stddev, global_state.arity, global_state.l1_bound) def preprocess_record(self, params, record): """Implements `tensorflow_privacy.DPQuery.preprocess_record`. - This method builds the tree, flattens it and applies - `inner_query.preprocess_record` to the flattened tree. + This method clips the input record by L1 norm, constructs a tree on top of + it, and adds gaussian noise to each node of the tree for differential + privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function + flattens the `tf.RaggedTensor` before outputting it. This is useful when + used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does + not accept ragged output tensor. Args: - params: Hyper-parameters for preprocessing record. - record: A histogram representing the leaf nodes of the tree. + params: hyper-parameters for preprocessing record, (stddev, aritry, + l1_bound) + record: leaf nodes for the tree. Returns: - A `tf.Tensor` representing the flattened version of the preprocessed tree. + `tf.Tensor` representing the flattened version of the tree. """ - arity, inner_query_params = params - preprocessed_record = _build_tree_from_leaf(record, arity).flat_values - preprocessed_record = self._inner_query.preprocess_record( - inner_query_params, preprocessed_record) + _, arity, l1_bound_ = params + l1_bound = tf.cast(l1_bound_, tf.float32) + + casted_record = tf.cast(record, tf.float32) + l1_norm = tf.norm(casted_record, ord=1) + + preprocessed_record, _ = tf.clip_by_global_norm([casted_record], + l1_bound, + use_norm=l1_norm) + preprocessed_record = preprocessed_record[0] + + add_noise = _get_add_noise(self._stddev, self._seed) + tree = _build_tree_from_leaf(preprocessed_record, arity) + noisy_tree = tf.map_fn(add_noise, tree) # The following codes reshape the output vector so the output shape of can # be statically inferred. This is useful when used with # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know # the output shape of this function statically and explicitly. - preprocessed_record_shape = [ + flat_noisy_tree = noisy_tree.flat_values + flat_tree_shape = [ (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - 1) // (self._arity - 1) ] - return tf.reshape(preprocessed_record, preprocessed_record_shape) + return tf.reshape(flat_noisy_tree, flat_tree_shape) def get_noised_result(self, sample_state, global_state): """Implements `tensorflow_privacy.DPQuery.get_noised_result`. @@ -525,11 +661,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): output by `preprocess_records.` Args: - sample_state: A `tf.Tensor` for the flattened tree. - global_state: The global state of the protocol. + sample_state: `tf.Tensor` for the flattened tree. + global_state: hyper-parameters including noise multiplier, the branching + factor of the tree and the maximum records per user. Returns: - A `tf.RaggedTensor` representing the tree. + a `tf.RaggedTensor` for the tree. """ # The [0] is needed because of how tf.RaggedTensor.from_two_splits works. # print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6], @@ -545,60 +682,3 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): tree = tf.RaggedTensor.from_row_splits( values=sample_state, row_splits=row_splits) return tree, global_state - - @classmethod - def build_central_gaussian_query(cls, - l2_norm_clip: float, - stddev: float, - arity: int = 2): - """Returns `TreeRangeSumQuery` with central Gaussian noise. - - Args: - l2_norm_clip: Each record should be clipped so that it has L2 norm at most - `l2_norm_clip`. - stddev: Stddev of the central Gaussian noise. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). Defaults to 2. - """ - if l2_norm_clip <= 0: - raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.') - - if stddev < 0: - raise ValueError(f'`stddev` must be non-negative, got {stddev}.') - - if arity < 2: - raise ValueError(f'`arity` must be at least 2, got {arity}.') - - inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev) - - return cls(arity=arity, inner_query=inner_query) - - @classmethod - def build_distributed_discrete_gaussian_query(cls, - l2_norm_bound: float, - local_stddev: float, - arity: int = 2): - """Returns `TreeRangeSumQuery` with central Gaussian noise. - - Args: - l2_norm_bound: Each record should be clipped so that it has L2 norm at - most `l2_norm_bound`. - local_stddev: Scale/stddev of the local discrete Gaussian noise. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). Defaults to 2. - """ - if l2_norm_bound <= 0: - raise ValueError( - f'`l2_clip_bound` must be positive, got {l2_norm_bound}.') - - if local_stddev < 0: - raise ValueError( - f'`local_stddev` must be non-negative, got {local_stddev}.') - - if arity < 2: - raise ValueError(f'`arity` must be at least 2, got {arity}.') - - inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery( - l2_norm_bound, local_stddev) - - return cls(arity=arity, inner_query=inner_query) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index 3713b5d..cc3a89a 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -423,115 +423,72 @@ class BuildTreeTest(tf.test.TestCase, parameterized.TestCase): self.assertEqual(tree[layer][idx], expected_value) -class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase): +class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): - @parameterized.product( - inner_query=['central', 'distributed'], - params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)], - ) - def test_raises_error(self, inner_query, params): - clip_norm, stddev, arity = params - with self.assertRaises(ValueError): - if inner_query == 'central': - tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - clip_norm, stddev, arity) - elif inner_query == 'distributed': - tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - clip_norm, stddev, arity) + def test_initial_global_state_type(self): - @parameterized.product( - inner_query=['central', 'distributed'], - clip_norm=[0.1, 1.0, 10.0], - stddev=[0.1, 1.0, 10.0]) - def test_initial_global_state_type(self, inner_query, clip_norm, stddev): - - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - clip_norm, stddev) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - clip_norm, stddev) + query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) global_state = query.initial_global_state() - self.assertIsInstance(global_state, - tree_aggregation_query.TreeRangeSumQuery.GlobalState) + self.assertIsInstance( + global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState) - @parameterized.product( - inner_query=['central', 'distributed'], - clip_norm=[0.1, 1.0, 10.0], - stddev=[0.1, 1.0, 10.0], - arity=[2, 3, 4]) - def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity): - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - clip_norm, stddev, arity) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - clip_norm, stddev, arity) - global_state = query.initial_global_state() - derived_arity, inner_query_state = query.derive_sample_params(global_state) - self.assertAllClose(derived_arity, arity) - if inner_query == 'central': - self.assertAllClose(inner_query_state, clip_norm) - elif inner_query == 'distributed': - self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm) - self.assertAllClose(inner_query_state.local_stddev, stddev) - - @parameterized.product( - (dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]), - dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])), - inner_query=['central', 'distributed'], - ) - def test_preprocess_record(self, inner_query, arity, expected_tree): - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - 10., 0., arity) - record = tf.constant([1, 0, 0, 0], dtype=tf.float32) - expected_tree = tf.cast(expected_tree, tf.float32) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - 10., 0., arity) - record = tf.constant([1, 0, 0, 0], dtype=tf.int32) + def test_derive_sample_params(self): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_tree) + self.assertAllClose(params, 10.) @parameterized.named_parameters( - ('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]), - ('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]), + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], + dtype=tf.float32)), ) - def test_distributed_preprocess_record_with_noise(self, local_stddev, record, - expected_tree): - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - 10., local_stddev) + def test_preprocess_record(self, arity, record): + query = tree_aggregation_query.CentralTreeSumQuery( + stddev=NOISE_STD, arity=arity) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose( - preprocessed_record, expected_tree, atol=10 * local_stddev) + self.assertAllClose(preprocessed_record, record) - @parameterized.product( - (dict( - arity=2, - expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])), - dict( - arity=3, - expected_tree=tf.ragged.constant([[1], [1, 0, 0], - [1, 0, 0, 0, 0, 0, 0, 0, 0]]))), - inner_query=['central', 'distributed'], + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.constant([5, 5, 0, 0], dtype=tf.int32)), + ('binary_test_float', 2, tf.constant( + [10., 10., 0., 0.], + dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.constant([5, 5, 0, 0], dtype=tf.int32)), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.constant([5., 5., 0., 0.], dtype=tf.float32)), ) - def test_get_noised_result(self, inner_query, arity, expected_tree): - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - 10., 0., arity) - record = tf.constant([1, 0, 0, 0], dtype=tf.float32) - expected_tree = tf.cast(expected_tree, tf.float32) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - 10., 0., arity) - record = tf.constant([1, 0, 0, 0], dtype=tf.int32) + def test_preprocess_record_clipped(self, arity, record, + expected_clipped_value): + query = tree_aggregation_query.CentralTreeSumQuery( + stddev=NOISE_STD, arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + self.assertAllClose(preprocessed_record, expected_clipped_value) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result(self, arity, record, expected_tree): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) preprocessed_record = query.preprocess_record(params, record) @@ -540,18 +497,167 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(sample_state, expected_tree) - @parameterized.product(stddev=[0.1, 1.0, 10.0]) - def test_central_get_noised_result_with_noise(self, stddev): - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - 10., stddev) + @parameterized.named_parameters( + ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ) + def test_get_noised_result_with_noise(self, stddev, record, expected_tree): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.])) + preprocessed_record = query.preprocess_record(params, record) + + sample_state, _ = query.get_noised_result(preprocessed_record, global_state) + + self.assertAllClose( + sample_state.flat_values, expected_tree, atol=3 * stddev) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result_clipped(self, arity, record, expected_tree): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) sample_state, global_state = query.get_noised_result( preprocessed_record, global_state) - self.assertAllClose( - sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev) + self.assertAllClose(sample_state, expected_tree) + + +class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_initial_global_state_type(self): + + query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) + global_state = query.initial_global_state() + self.assertIsInstance( + global_state, + tree_aggregation_query.DistributedTreeSumQuery.GlobalState) + + def test_derive_sample_params(self): + query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) + global_state = query.initial_global_state() + stddev, arity, l1_bound = query.derive_sample_params(global_state) + self.assertAllClose(stddev, NOISE_STD) + self.assertAllClose(arity, 2) + self.assertAllClose(l1_bound, 10) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. + ])), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. + ])), + ) + def test_preprocess_record(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + self.assertAllClose(preprocessed_record, expected_tree) + + @parameterized.named_parameters( + ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ) + def test_preprocess_record_with_noise(self, stddev, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=stddev, seed=0) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + + preprocessed_record = query.preprocess_record(params, record) + + self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), + ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant( + [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant( + [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), + ) + def test_preprocess_record_clipped(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + self.assertAllClose(preprocessed_record, expected_tree) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + sample_state, global_state = query.get_noised_result( + preprocessed_record, global_state) + + self.assertAllClose(sample_state, expected_tree) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result_clipped(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + sample_state, global_state = query.get_noised_result( + preprocessed_record, global_state) + + self.assertAllClose(sample_state, expected_tree) if __name__ == '__main__':