From caf6f36b80b4f2dd4ff0ea7fcbd6c02054c3d71c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jul 2021 15:54:53 -0700 Subject: [PATCH] (1) add `CentralTreeSumQuery` and `DistributedTreeSumQuery` to tree_aggregation_query.py. (2) move `build_tree_from_leaf` to tree_aggregation_query.py together with `CentralTreeSumQuery`. PiperOrigin-RevId: 383511025 --- .../privacy/dp_query/tree_aggregation.py | 80 ----- .../dp_query/tree_aggregation_query.py | 317 +++++++++++++++++- .../dp_query/tree_aggregation_query_test.py | 285 +++++++++++++++- .../privacy/dp_query/tree_aggregation_test.py | 34 -- 4 files changed, 594 insertions(+), 122 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py index 3561be3..ba8ea2f 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py @@ -17,10 +17,6 @@ based on tree aggregation. When using an appropriate noise function (e.g., Gaussian noise), it allows for efficient differentially private algorithms under continual observation, without prior subsampling or shuffling assumptions. - -`build_tree` constructs a tree given the leaf nodes by recursively summing the -children nodes to get the parent node. It allows for efficient range queries and -other statistics such as quantiles on the leaf nodes. """ import abc @@ -449,79 +445,3 @@ class EfficientTreeAggregator(): level_buffer_idx=new_level_buffer_idx, value_generator_state=value_generator_state) return cumsum, new_state - - -@tf.function -def build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: - """A function constructs a complete tree given all the leaf nodes. - - The function takes a 1-D array representing the leaf nodes of a tree and the - tree's arity, and constructs a complete tree by recursively summing the - adjacent children to get the parent until reaching the root node. Because we - assume a complete tree, if the number of leaf nodes does not divide arity, the - leaf nodes will be padded with zeros. - - Args: - leaf_nodes: A 1-D array storing the leaf nodes of the tree. - arity: A `int` for the branching factor of the tree, i.e. the number of - children for each internal node. - - Returns: - `tf.RaggedTensor` representing the tree. For example, if - `leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value - should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way, - `tree[layer][index]` can be used to access the node indexed by (layer, - index) in the tree, - - Raises: - ValueError: if parameters don't meet expectations. There are two situations - where the error is raised: (1) the input tensor has length smaller than 1; - (2) The arity is less than 2. - """ - - if len(leaf_nodes) <= 0: - raise ValueError( - 'The number of leaf nodes should at least be 1.' - f'However, an array of length {len(leaf_nodes)} is detected') - - if arity <= 1: - raise ValueError('The branching factor should be at least 2.' - f'However, a branching factor of {arity} is detected.') - - def pad_zero(leaf_nodes, size): - paddings = [[0, size - len(leaf_nodes)]] - return tf.pad(leaf_nodes, paddings) - - leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32) - num_layers = tf.math.ceil( - tf.math.log(leaf_nodes_size) / - tf.math.log(tf.constant(arity, dtype=tf.float32))) + 1 - leaf_nodes = pad_zero(leaf_nodes, tf.math.pow(float(arity), num_layers - 1)) - - def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor: - return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1) - - # The following `tf.while_loop` constructs the tree from bottom up by - # iteratively applying `_shrink_layer` to each layer of the tree. The reason - # for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not - # support auto-translation from python loop to tf loop when loop variables - # contain a `RaggedTensor` whose shape changes across iterations. - - idx = tf.identity(num_layers) - loop_cond = lambda i, h: tf.less_equal(2.0, i) - - def _loop_body(i, h): - return [ - tf.add(i, -1.0), - tf.concat(([_shrink_layer(h[0], arity)], h), axis=0) - ] - - _, tree = tf.while_loop( - loop_cond, - _loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])], - shape_invariants=[ - idx.get_shape(), - tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1) - ]) - - return tree diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index fb7dc76..79bc243 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -11,9 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""DPQuery for continual observation queries relying on `tree_aggregation`.""" +"""`DPQuery`s for differentially private tree aggregation protocols. +`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual +online observation queries relying on `tree_aggregation`. 'Online' means that +the leaf nodes of the tree arrive one by one as the time proceeds. The leaves +are vector records as defined in `dp_query.DPQuery`. + +`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for +central/distributed offline tree aggregation protocol. 'Offline' means all the +leaf nodes are ready before the protocol starts. Each record, different from +what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes). +""" +import distutils +import math import attr + import tensorflow as tf from tensorflow_privacy.privacy.dp_query import dp_query @@ -31,11 +44,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): Attributes: clip_fn: Callable that specifies clipping function. `clip_fn` receives two arguments: a flat list of vars in a record and a `clip_value` to clip the - corresponding record, e.g. clip_fn(flat_record, clip_value). + corresponding record, e.g. clip_fn(flat_record, clip_value). clip_value: float indicating the value at which to clip the record. record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. - tree_aggregator: `tree_aggregation.TreeAggregator` initialized with - user defined `noise_generator`. `noise_generator` is a + tree_aggregator: `tree_aggregation.TreeAggregator` initialized with user + defined `noise_generator`. `noise_generator` is a `tree_aggregation.ValueGenerator` to generate the noise value for a tree node. Noise stdandard deviation is specified outside the `dp_query` by the user when defining `noise_fn` and should have order @@ -209,7 +222,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): Attributes: clip_fn: Callable that specifies clipping function. `clip_fn` receives two arguments: a flat list of vars in a record and a `clip_value` to clip the - corresponding record, e.g. clip_fn(flat_record, clip_value). + corresponding record, e.g. clip_fn(flat_record, clip_value). clip_value: float indicating the value at which to clip the record. record_specs: A nested structure of `tf.TensorSpec`s specifying structure and shapes of records. @@ -364,3 +377,297 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): record_specs=record_specs, noise_generator=gaussian_noise_generator, use_efficient=use_efficient) + + +@tf.function +def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: + """A function constructs a complete tree given all the leaf nodes. + + The function takes a 1-D array representing the leaf nodes of a tree and the + tree's arity, and constructs a complete tree by recursively summing the + adjacent children to get the parent until reaching the root node. Because we + assume a complete tree, if the number of leaf nodes does not divide arity, the + leaf nodes will be padded with zeros. + + Args: + leaf_nodes: A 1-D array storing the leaf nodes of the tree. + arity: A `int` for the branching factor of the tree, i.e. the number of + children for each internal node. + + Returns: + `tf.RaggedTensor` representing the tree. For example, if + `leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value + should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way, + `tree[layer][index]` can be used to access the node indexed by (layer, + index) in the tree, + """ + + def pad_zero(leaf_nodes, size): + paddings = [[0, size - len(leaf_nodes)]] + return tf.pad(leaf_nodes, paddings) + + leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32) + num_layers = tf.math.ceil( + tf.math.log(leaf_nodes_size) / + tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1 + leaf_nodes = pad_zero( + leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1)) + + def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor: + return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1) + + # The following `tf.while_loop` constructs the tree from bottom up by + # iteratively applying `_shrink_layer` to each layer of the tree. The reason + # for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not + # support auto-translation from python loop to tf loop when loop variables + # contain a `RaggedTensor` whose shape changes across iterations. + + idx = tf.identity(num_layers) + loop_cond = lambda i, h: tf.less_equal(2.0, i) + + def _loop_body(i, h): + return [ + tf.add(i, -1.0), + tf.concat(([_shrink_layer(h[0], arity)], h), axis=0) + ] + + _, tree = tf.while_loop( + loop_cond, + _loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])], + shape_invariants=[ + idx.get_shape(), + tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1) + ]) + + return tree + + +def _get_add_noise(stddev): + """Utility function to decide which `add_noise` to use according to tf version.""" + if distutils.version.LooseVersion( + tf.__version__) < distutils.version.LooseVersion('2.0.0'): + + def add_noise(v): + return v + tf.random.normal( + tf.shape(input=v), stddev=stddev, dtype=v.dtype) + else: + random_normal = tf.random_normal_initializer(stddev=stddev) + + def add_noise(v): + return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) + + return add_noise + + +class CentralTreeSumQuery(dp_query.SumAggregationDPQuery): + """Implements dp_query for differentially private tree aggregation protocol. + + Implements a central variant of the tree aggregation protocol from the paper + "'Is interaction necessary for distributed private learning?.' Adam Smith, + Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with + gaussian mechanism. The first step is to clip the clients' local updates (i.e. + a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it + does not exceed a prespecified upper bound. The second step is to construct + the tree on the clipped update. The third step is to add independent gaussian + noise to each node in the tree. The returned tree can support efficient and + accurate range queries with differential privacy. + """ + + @attr.s(frozen=True) + class GlobalState(object): + """Class defining global state for `CentralTreeSumQuery`. + + Attributes: + stddev: The stddev of the noise added to each node in the tree. + arity: The branching factor of the tree (i.e. the number of children each + internal node has). + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. + """ + stddev = attr.ib() + arity = attr.ib() + l1_bound = attr.ib() + + def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10): + """Initializes the `CentralTreeSumQuery`. + + Args: + stddev: The stddev of the noise added to each internal node of the + constructed tree. + arity: The branching factor of the tree. + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. + """ + self._stddev = stddev + self._arity = arity + self._l1_bound = l1_bound + + def initial_global_state(self): + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" + return CentralTreeSumQuery.GlobalState( + stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound) + + def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" + return global_state.l1_bound + + def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" + casted_record = tf.cast(record, tf.float32) + l1_norm = tf.norm(casted_record, ord=1) + + l1_bound = tf.cast(params, tf.float32) + + preprocessed_record, _ = tf.clip_by_global_norm([casted_record], + l1_bound, + use_norm=l1_norm) + + return preprocessed_record[0] + + def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`. + + Args: + sample_state: a frequency histogram. + global_state: hyper-parameters of the query. + + Returns: + a `tf.RaggedTensor` representing the tree built on top of `sample_state`. + The jth node on the ith layer of the tree can be accessed by tree[i][j] + where tree is the returned value. + """ + add_noise = _get_add_noise(self._stddev) + tree = _build_tree_from_leaf(sample_state, global_state.arity) + return tf.nest.map_structure( + add_noise, tree, expand_composites=True), global_state + + +class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery): + """Implements dp_query for differentially private tree aggregation protocol. + + The difference from `CentralTreeSumQuery` is that the tree construction and + gaussian noise addition happen in `preprocess_records`. The difference only + takes effect when used together with + `tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class + should be treated as equal with `CentralTreeSumQuery`. + + Implements a distributed version of the tree aggregation protocol from. "Is + interaction necessary for distributed private learning?." by replacing their + local randomizer with gaussian mechanism. The first step is to check the L1 + norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes + of the tree) to make sure it does not exceed a prespecified upper bound. The + second step is to construct the tree. The third step is to add independent + gaussian noise to each node in the tree. The returned tree can support + efficient and accurate range queries with differential privacy. + """ + + @attr.s(frozen=True) + class GlobalState(object): + """Class defining global state for DistributedTreeSumQuery. + + Attributes: + stddev: The stddev of the noise added to each internal node in the + constructed tree. + arity: The branching factor of the tree (i.e. the number of children each + internal node has). + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. + """ + stddev = attr.ib() + arity = attr.ib() + l1_bound = attr.ib() + + def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10): + """Initializes the `DistributedTreeSumQuery`. + + Args: + stddev: The stddev of the noise added to each node in the tree. + arity: The branching factor of the tree. + l1_bound: An upper bound on the L1 norm of the input record. This is + needed to bound the sensitivity and deploy differential privacy. + """ + self._stddev = stddev + self._arity = arity + self._l1_bound = l1_bound + + def initial_global_state(self): + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" + return DistributedTreeSumQuery.GlobalState( + stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound) + + def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" + return (global_state.stddev, global_state.arity, global_state.l1_bound) + + def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`. + + This method clips the input record by L1 norm, constructs a tree on top of + it, and adds gaussian noise to each node of the tree for differential + privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function + flattens the `tf.RaggedTensor` before outputting it. This is useful when + used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does + not accept ragged output tensor. + + Args: + params: hyper-parameters for preprocessing record, (stddev, aritry, + l1_bound) + record: leaf nodes for the tree. + + Returns: + `tf.Tensor` representing the flattened version of the tree. + """ + _, arity, l1_bound_ = params + l1_bound = tf.cast(l1_bound_, tf.float32) + + casted_record = tf.cast(record, tf.float32) + l1_norm = tf.norm(casted_record, ord=1) + + preprocessed_record, _ = tf.clip_by_global_norm([casted_record], + l1_bound, + use_norm=l1_norm) + preprocessed_record = preprocessed_record[0] + + add_noise = _get_add_noise(self._stddev) + tree = _build_tree_from_leaf(preprocessed_record, arity) + noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True) + + # The following codes reshape the output vector so the output shape of can + # be statically inferred. This is useful when used with + # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know + # the output shape of this function statically and explicitly. + flat_noisy_tree = noisy_tree.flat_values + flat_tree_shape = [ + (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - + 1) // (self._arity - 1) + ] + return tf.reshape(flat_noisy_tree, flat_tree_shape) + + def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`. + + This function re-constructs the `tf.RaggedTensor` from the flattened tree + output by `preprocess_records.` + + Args: + sample_state: `tf.Tensor` for the flattened tree. + global_state: hyper-parameters including noise multiplier, the branching + factor of the tree and the maximum records per user. + + Returns: + a `tf.RaggedTensor` for the tree. + """ + # The [0] is needed because of how tf.RaggedTensor.from_two_splits works. + # print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6], + # row_splits=[0, 4, 4, 7, 8, 8])) + # + # This part is not written in tensorflow and will be executed on the server + # side instead of the client side if used with + # tff.aggregators.DifferentiallyPrivateFactory for federated learning. + row_splits = [0] + [ + (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range( + math.floor(math.log(sample_state.shape[0], self._arity)) + 1) + ] + tree = tf.RaggedTensor.from_row_splits( + values=sample_state, row_splits=row_splits) + return tree, global_state diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index 8cf2157..34f2c9c 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -13,16 +13,15 @@ # limitations under the License. """Tests for `tree_aggregation_query`.""" -from absl.testing import parameterized +import math +from absl.testing import parameterized import numpy as np import tensorflow as tf - from tensorflow_privacy.privacy.dp_query import test_utils from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation_query - STRUCT_RECORD = [ tf.constant([[2.0, 0.0], [0.0, 1.0]]), tf.constant([-1.0, 0.0]) @@ -55,6 +54,7 @@ def _get_noise_fn(specs, stddev=NOISE_STD, seed=1): def _get_no_noise_fn(specs): shape = tf.nest.map_structure(lambda spec: spec.shape, specs) + def no_noise_fn(): return tf.nest.map_structure(tf.zeros, shape) @@ -73,6 +73,7 @@ def _get_l2_clip_fn(): def _get_l_infty_clip_fn(): def l_infty_clip_fn(record_as_list, clip_value): + def clip(record): return tf.clip_by_value( record, clip_value_min=-clip_value, clip_value_max=clip_value) @@ -395,5 +396,283 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertIsInstance(query._tree_aggregator, tree_class) +class BuildTreeTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + leaf_nodes_size=[1, 2, 3, 4, 5], + arity=[2, 3], + dtype=[tf.int32, tf.float32], + ) + def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype): + """Test whether `_build_tree_from_leaf` will output the correct tree.""" + + leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype) + depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1 + + tree = tree_aggregation_query._build_tree_from_leaf(leaf_nodes, arity) + + self.assertEqual(depth, tree.shape[0]) + + for layer in range(depth): + reverse_depth = tree.shape[0] - layer - 1 + span_size = arity**reverse_depth + for idx in range(arity**layer): + left = idx * span_size + right = (idx + 1) * span_size + expected_value = sum(leaf_nodes[left:right]) + self.assertEqual(tree[layer][idx], expected_value) + + +class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_initial_global_state_type(self): + + query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) + global_state = query.initial_global_state() + self.assertIsInstance( + global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState) + + def test_derive_sample_params(self): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + self.assertAllClose(params, 10.) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], + dtype=tf.float32)), + ) + def test_preprocess_record(self, arity, record): + query = tree_aggregation_query.CentralTreeSumQuery( + stddev=NOISE_STD, arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + + self.assertAllClose(preprocessed_record, record) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.constant([5, 5, 0, 0], dtype=tf.int32)), + ('binary_test_float', 2, tf.constant( + [10., 10., 0., 0.], + dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.constant([5, 5, 0, 0], dtype=tf.int32)), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.constant([5., 5., 0., 0.], dtype=tf.float32)), + ) + def test_preprocess_record_clipped(self, arity, record, + expected_clipped_value): + query = tree_aggregation_query.CentralTreeSumQuery( + stddev=NOISE_STD, arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + self.assertAllClose(preprocessed_record, expected_clipped_value) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result(self, arity, record, expected_tree): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + sample_state, global_state = query.get_noised_result( + preprocessed_record, global_state) + + self.assertAllClose(sample_state, expected_tree) + + @parameterized.named_parameters( + ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ) + def test_get_noised_result_with_noise(self, stddev, record, expected_tree): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + sample_state_list = [] + for _ in range(1000): + sample_state, _ = query.get_noised_result(preprocessed_record, + global_state) + sample_state_list.append(sample_state.flat_values.numpy()) + expectation = np.mean(sample_state_list, axis=0) + variance = np.std(sample_state_list, axis=0) + + self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4) + self.assertAllClose( + variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result_clipped(self, arity, record, expected_tree): + query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + sample_state, global_state = query.get_noised_result( + preprocessed_record, global_state) + + self.assertAllClose(sample_state, expected_tree) + + +class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_initial_global_state_type(self): + + query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) + global_state = query.initial_global_state() + self.assertIsInstance( + global_state, + tree_aggregation_query.DistributedTreeSumQuery.GlobalState) + + def test_derive_sample_params(self): + query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) + global_state = query.initial_global_state() + stddev, arity, l1_bound = query.derive_sample_params( + global_state) + self.assertAllClose(stddev, NOISE_STD) + self.assertAllClose(arity, 2) + self.assertAllClose(l1_bound, 10) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. + ])), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. + ])), + ) + def test_preprocess_record(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + self.assertAllClose(preprocessed_record, expected_tree) + + @parameterized.named_parameters( + ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), + ) + def test_preprocess_record_with_noise(self, stddev, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + + preprocessed_record_list = [] + for _ in range(1000): + preprocessed_record = query.preprocess_record(params, record) + preprocessed_record_list.append(preprocessed_record.numpy()) + + expectation = np.mean(preprocessed_record_list, axis=0) + variance = np.std(preprocessed_record_list, axis=0) + + self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4) + self.assertAllClose( + variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), + ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant( + [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant( + [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), + ) + def test_preprocess_record_clipped(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + self.assertAllClose(preprocessed_record, expected_tree) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), + tf.ragged.constant([[1.], [1., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + sample_state, global_state = query.get_noised_result( + preprocessed_record, global_state) + + self.assertAllClose(sample_state, expected_tree) + + @parameterized.named_parameters( + ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), + ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], + dtype=tf.float32), + tf.ragged.constant([[10.], [10., 0., 0.], + [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), + ) + def test_get_noised_result_clipped(self, arity, record, expected_tree): + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=0., arity=arity) + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + preprocessed_record = query.preprocess_record(params, record) + sample_state, global_state = query.get_noised_result( + preprocessed_record, global_state) + + self.assertAllClose(sample_state, expected_tree) + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py index 9a237ad..9a8be35 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py @@ -365,39 +365,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase): self.assertAllEqual(gstate, gstate2) -class BuildTreeTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.product( - leaf_nodes_size=[1, 2, 3, 4, 5], - arity=[2, 3], - dtype=[tf.int32, tf.float32], - ) - def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype): - """Test whether `build_tree_from_leaf` will output the correct tree.""" - - leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype) - depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1 - - tree = tree_aggregation.build_tree_from_leaf(leaf_nodes, arity) - - self.assertEqual(depth, tree.shape[0]) - - for layer in range(depth): - reverse_depth = tree.shape[0] - layer - 1 - span_size = arity**reverse_depth - for idx in range(arity**layer): - left = idx * span_size - right = (idx + 1) * span_size - expected_value = sum(leaf_nodes[left:right]) - self.assertEqual(tree[layer][idx], expected_value) - - @parameterized.named_parameters(('negative_arity', [1], -1), - ('empty_hist', [], 2)) - def test_value_error_raises(self, leaf_nodes, arity): - """Test whether `build_tree_from_leaf` will raise the correct error when the input is illegal.""" - with self.assertRaises(ValueError): - tree_aggregation.build_tree_from_leaf(leaf_nodes, arity) - - if __name__ == '__main__': tf.test.main()