From a9764e3e7d4795d4ce8bd724b5d9347be110769c Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Wed, 8 Sep 2021 21:05:28 -0700 Subject: [PATCH] TFF: cleanup the TFP query usage in tff.analytics; remove dependency on internal TFP structure. TFP: remove duplicate TreeRangeSumQuery in `tree_aggregation_query` PiperOrigin-RevId: 395618363 --- .../dp_query/tree_aggregation_query.py | 265 ------------------ .../dp_query/tree_aggregation_query_test.py | 160 ----------- 2 files changed, 425 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 1219651..2752dba 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -19,15 +19,11 @@ the leaf nodes of the tree arrive one by one as the time proceeds. The core logic of tree aggregation is implemented in `tree_aggregation.TreeAggregator` and `tree_aggregation.EfficientTreeAggregator`. """ -import distutils -import math import attr import tensorflow as tf from tensorflow_privacy.privacy.analysis import dp_event -from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query from tensorflow_privacy.privacy.dp_query import dp_query -from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import tree_aggregation # TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be @@ -468,264 +464,3 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): record_specs=record_specs, noise_generator=gaussian_noise_generator, use_efficient=use_efficient) - - -# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next -# TFP release - - -@tf.function -def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: - """A function constructs a complete tree given all the leaf nodes. - - The function takes a 1-D array representing the leaf nodes of a tree and the - tree's arity, and constructs a complete tree by recursively summing the - adjacent children to get the parent until reaching the root node. Because we - assume a complete tree, if the number of leaf nodes does not divide arity, the - leaf nodes will be padded with zeros. - - Args: - leaf_nodes: A 1-D array storing the leaf nodes of the tree. - arity: A `int` for the branching factor of the tree, i.e. the number of - children for each internal node. - - Returns: - `tf.RaggedTensor` representing the tree. For example, if - `leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value - should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way, - `tree[layer][index]` can be used to access the node indexed by (layer, - index) in the tree, - """ - - def pad_zero(leaf_nodes, size): - paddings = [[0, size - len(leaf_nodes)]] - return tf.pad(leaf_nodes, paddings) - - leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32) - num_layers = tf.math.ceil( - tf.math.log(leaf_nodes_size) / - tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1 - leaf_nodes = pad_zero( - leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1)) - - def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor: - return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1) - - # The following `tf.while_loop` constructs the tree from bottom up by - # iteratively applying `_shrink_layer` to each layer of the tree. 
The reason - # for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not - # support auto-translation from python loop to tf loop when loop variables - # contain a `RaggedTensor` whose shape changes across iterations. - - idx = tf.identity(num_layers) - loop_cond = lambda i, h: tf.less_equal(2.0, i) - - def _loop_body(i, h): - return [ - tf.add(i, -1.0), - tf.concat(([_shrink_layer(h[0], arity)], h), axis=0) - ] - - _, tree = tf.while_loop( - loop_cond, - _loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])], - shape_invariants=[ - idx.get_shape(), - tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1) - ]) - - return tree - - -class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for accurate range queries using tree aggregation. - - Implements a variant of the tree aggregation protocol from. "Is interaction - necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta, - Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to - the tree for differential privacy. Any range query can be decomposed into the - sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram. - Improves efficiency and reduces noise scale. - """ - - @attr.s(frozen=True) - class GlobalState(object): - """Class defining global state for TreeRangeSumQuery. - - Attributes: - arity: The branching factor of the tree (i.e. the number of children each - internal node has). - inner_query_state: The global state of the inner query. - """ - arity = attr.ib() - inner_query_state = attr.ib() - - def __init__(self, - inner_query: dp_query.SumAggregationDPQuery, - arity: int = 2): - """Initializes the `TreeRangeSumQuery`. - - Args: - inner_query: The inner `DPQuery` that adds noise to the tree. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). Defaults to 2. - """ - self._inner_query = inner_query - self._arity = arity - - if self._arity < 1: - raise ValueError(f'Invalid arity={arity} smaller than 2.') - - def initial_global_state(self): - """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" - return TreeRangeSumQuery.GlobalState( - arity=self._arity, - inner_query_state=self._inner_query.initial_global_state()) - - def derive_sample_params(self, global_state): - """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" - return (global_state.arity, - self._inner_query.derive_sample_params( - global_state.inner_query_state)) - - def preprocess_record(self, params, record): - """Implements `tensorflow_privacy.DPQuery.preprocess_record`. - - This method builds the tree, flattens it and applies - `inner_query.preprocess_record` to the flattened tree. - - Args: - params: Hyper-parameters for preprocessing record. - record: A histogram representing the leaf nodes of the tree. - - Returns: - A `tf.Tensor` representing the flattened version of the preprocessed tree. - """ - arity, inner_query_params = params - preprocessed_record = _build_tree_from_leaf(record, arity).flat_values - # The following codes reshape the output vector so the output shape of can - # be statically inferred. This is useful when used with - # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know - # the output shape of this function statically and explicitly. 
- preprocessed_record_shape = [ - (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - - 1) // (self._arity - 1) - ] - preprocessed_record = tf.reshape(preprocessed_record, - preprocessed_record_shape) - preprocessed_record = self._inner_query.preprocess_record( - inner_query_params, preprocessed_record) - - return preprocessed_record - - def get_noised_result(self, sample_state, global_state): - """Implements `tensorflow_privacy.DPQuery.get_noised_result`. - - This function re-constructs the `tf.RaggedTensor` from the flattened tree - output by `preprocess_records.` - - Args: - sample_state: A `tf.Tensor` for the flattened tree. - global_state: The global state of the protocol. - - Returns: - A `tf.RaggedTensor` representing the tree. - """ - # The [0] is needed because of how tf.RaggedTensor.from_two_splits works. - # print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6], - # row_splits=[0, 4, 4, 7, 8, 8])) - # - # This part is not written in tensorflow and will be executed on the server - # side instead of the client side if used with - # tff.aggregators.DifferentiallyPrivateFactory for federated learning. - sample_state, inner_query_state, _ = self._inner_query.get_noised_result( - sample_state, global_state.inner_query_state) - new_global_state = TreeRangeSumQuery.GlobalState( - arity=global_state.arity, inner_query_state=inner_query_state) - - row_splits = [0] + [ - (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range( - math.floor(math.log(sample_state.shape[0], self._arity)) + 1) - ] - tree = tf.RaggedTensor.from_row_splits( - values=sample_state, row_splits=row_splits) - event = dp_event.UnsupportedDpEvent() - return tree, new_global_state, event - - @classmethod - def build_central_gaussian_query(cls, - l2_norm_clip: float, - stddev: float, - arity: int = 2): - """Returns `TreeRangeSumQuery` with central Gaussian noise. - - Args: - l2_norm_clip: Each record should be clipped so that it has L2 norm at most - `l2_norm_clip`. - stddev: Stddev of the central Gaussian noise. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). Defaults to 2. - """ - if l2_norm_clip <= 0: - raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.') - - if stddev < 0: - raise ValueError(f'`stddev` must be non-negative, got {stddev}.') - - if arity < 2: - raise ValueError(f'`arity` must be at least 2, got {arity}.') - - inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev) - - return cls(arity=arity, inner_query=inner_query) - - @classmethod - def build_distributed_discrete_gaussian_query(cls, - l2_norm_bound: float, - local_stddev: float, - arity: int = 2): - """Returns `TreeRangeSumQuery` with central Gaussian noise. - - Args: - l2_norm_bound: Each record should be clipped so that it has L2 norm at - most `l2_norm_bound`. - local_stddev: Scale/stddev of the local discrete Gaussian noise. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). Defaults to 2. 
- """ - if l2_norm_bound <= 0: - raise ValueError( - f'`l2_clip_bound` must be positive, got {l2_norm_bound}.') - - if local_stddev < 0: - raise ValueError( - f'`local_stddev` must be non-negative, got {local_stddev}.') - - if arity < 2: - raise ValueError(f'`arity` must be at least 2, got {arity}.') - - inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery( - l2_norm_bound, local_stddev) - - return cls(arity=arity, inner_query=inner_query) - - -def _get_add_noise(stddev, seed: int = None): - """Utility function to decide which `add_noise` to use according to tf version.""" - if distutils.version.LooseVersion( - tf.__version__) < distutils.version.LooseVersion('2.0.0'): - - # The seed should be only used for testing purpose. - if seed is not None: - tf.random.set_seed(seed) - - def add_noise(v): - return v + tf.random.normal( - tf.shape(input=v), stddev=stddev, dtype=v.dtype) - else: - random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed) - - def add_noise(v): - return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) - - return add_noise diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index 1115f40..699c890 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -13,8 +13,6 @@ # limitations under the License. """Tests for `tree_aggregation_query`.""" -import math - from absl.testing import parameterized import numpy as np import tensorflow as tf @@ -470,163 +468,5 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertEqual(query_result, expected) -class BuildTreeTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.product( - leaf_nodes_size=[1, 2, 3, 4, 5], - arity=[2, 3], - dtype=[tf.int32, tf.float32], - ) - def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype): - """Test whether `_build_tree_from_leaf` will output the correct tree.""" - - leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype) - depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1 - - tree = tree_aggregation_query._build_tree_from_leaf(leaf_nodes, arity) - - self.assertEqual(depth, tree.shape[0]) - - for layer in range(depth): - reverse_depth = tree.shape[0] - layer - 1 - span_size = arity**reverse_depth - for idx in range(arity**layer): - left = idx * span_size - right = (idx + 1) * span_size - expected_value = sum(leaf_nodes[left:right]) - self.assertEqual(tree[layer][idx], expected_value) - - -class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.product( - inner_query=['central', 'distributed'], - params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)], - ) - def test_raises_error(self, inner_query, params): - clip_norm, stddev, arity = params - with self.assertRaises(ValueError): - if inner_query == 'central': - tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - clip_norm, stddev, arity) - elif inner_query == 'distributed': - tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - clip_norm, stddev, arity) - - @parameterized.product( - inner_query=['central', 'distributed'], - clip_norm=[0.1, 1.0, 10.0], - stddev=[0.1, 1.0, 10.0]) - def test_initial_global_state_type(self, inner_query, clip_norm, stddev): - - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - 
clip_norm, stddev) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - clip_norm, stddev) - global_state = query.initial_global_state() - self.assertIsInstance(global_state, - tree_aggregation_query.TreeRangeSumQuery.GlobalState) - - @parameterized.product( - inner_query=['central', 'distributed'], - clip_norm=[0.1, 1.0, 10.0], - stddev=[0.1, 1.0, 10.0], - arity=[2, 3, 4]) - def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity): - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - clip_norm, stddev, arity) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - clip_norm, stddev, arity) - global_state = query.initial_global_state() - derived_arity, inner_query_state = query.derive_sample_params(global_state) - self.assertAllClose(derived_arity, arity) - if inner_query == 'central': - self.assertAllClose(inner_query_state, clip_norm) - elif inner_query == 'distributed': - self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm) - self.assertAllClose(inner_query_state.local_stddev, stddev) - - @parameterized.product( - (dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]), - dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])), - inner_query=['central', 'distributed'], - ) - def test_preprocess_record(self, inner_query, arity, expected_tree): - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - 10., 0., arity) - record = tf.constant([1, 0, 0, 0], dtype=tf.float32) - expected_tree = tf.cast(expected_tree, tf.float32) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - 10., 0., arity) - record = tf.constant([1, 0, 0, 0], dtype=tf.int32) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_tree) - - @parameterized.named_parameters( - ('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]), - ('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]), - ) - def test_distributed_preprocess_record_with_noise(self, local_stddev, record, - expected_tree): - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - 10., local_stddev) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - - preprocessed_record = query.preprocess_record(params, record) - - self.assertAllClose( - preprocessed_record, expected_tree, atol=10 * local_stddev) - - @parameterized.product( - (dict( - arity=2, - expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])), - dict( - arity=3, - expected_tree=tf.ragged.constant([[1], [1, 0, 0], - [1, 0, 0, 0, 0, 0, 0, 0, 0]]))), - inner_query=['central', 'distributed'], - ) - def test_get_noised_result(self, inner_query, arity, expected_tree): - if inner_query == 'central': - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - 10., 0., arity) - record = tf.constant([1, 0, 0, 0], dtype=tf.float32) - expected_tree = tf.cast(expected_tree, tf.float32) - elif inner_query == 'distributed': - query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( - 10., 0., arity) - 
record = tf.constant([1, 0, 0, 0], dtype=tf.int32) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - sample_state, global_state, _ = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose(sample_state, expected_tree) - - @parameterized.product(stddev=[0.1, 1.0, 10.0]) - def test_central_get_noised_result_with_noise(self, stddev): - query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query( - 10., stddev) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.])) - sample_state, global_state, _ = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose( - sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev) - - if __name__ == '__main__': tf.test.main()
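
For reference, the removed `_build_tree_from_leaf` builds a complete tree by zero-padding the leaf layer to a power of the arity and then summing adjacent children layer by layer until the root is reached. Below is a minimal plain-Python sketch of that construction (the helper name `build_tree` is illustrative, not TFP API); it reproduces the example from the removed docstring, where leaves [1, 2, 3, 4] with arity 2 yield [[10], [3, 7], [1, 2, 3, 4]].

# Sketch (not TFP API): complete-tree construction matching the removed
# `_build_tree_from_leaf` docstring. Leaves are zero-padded to arity**(depth-1),
# then each group of `arity` children is summed to form the parent layer.
import math


def build_tree(leaf_nodes, arity=2):
    """Returns the tree as a list of layers, root first."""
    depth = math.ceil(math.log(len(leaf_nodes), arity)) + 1
    padded_size = arity ** (depth - 1)
    layer = list(leaf_nodes) + [0] * (padded_size - len(leaf_nodes))
    layers = [layer]
    while len(layer) > 1:
        # Sum adjacent groups of `arity` children to get the parent layer.
        layer = [sum(layer[i:i + arity]) for i in range(0, len(layer), arity)]
        layers.insert(0, layer)
    return layers


print(build_tree([1, 2, 3, 4], arity=2))  # [[10], [3, 7], [1, 2, 3, 4]]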
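The removed class docstring notes that any range query over the leaves can be answered from the sum of O(log(n)) tree nodes instead of O(n) histogram buckets. The following sketch shows that decomposition over a complete tree stored root-first as layers (again illustrative, not TFP API).

# Sketch (not TFP API): answer a range sum over the leaves by combining the
# largest tree nodes fully covered by [left, right); only O(arity * depth)
# nodes are touched instead of (right - left) leaves.
def range_sum(layers, left, right, arity=2):
    depth = len(layers)

    def visit(layer, idx):
        span = arity ** (depth - 1 - layer)  # number of leaves under this node
        lo, hi = idx * span, (idx + 1) * span
        if right <= lo or hi <= left:        # node entirely outside the range
            return 0
        if left <= lo and hi <= right:       # node entirely inside: use it
            return layers[layer][idx]
        # Partial overlap: recurse into the node's children.
        return sum(visit(layer + 1, idx * arity + c) for c in range(arity))

    return visit(0, 0)


layers = [[10], [3, 7], [1, 2, 3, 4]]        # tree over leaves [1, 2, 3, 4]
print(range_sum(layers, 0, 3))               # 1 + 2 + 3 = 6, from node 3 and leaf 3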
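End to end, the deleted tests exercise the query roughly as follows. This sketch assumes a tensorflow_privacy checkout that still ships `TreeRangeSumQuery` in `tree_aggregation_query` (i.e. a tree prior to this patch); with stddev 0 the central Gaussian path is deterministic, so the values match the removed test cases.

# Sketch mirroring the deleted tests; requires a pre-patch tensorflow_privacy
# build where `TreeRangeSumQuery` still lives in `tree_aggregation_query`.
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query

query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
    l2_norm_clip=10., stddev=0., arity=2)  # stddev=0 keeps the output exact

global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)

# A histogram over 4 buckets becomes a flattened complete binary tree.
record = tf.constant([1., 0., 0., 0.])
flat_tree = query.preprocess_record(params, record)  # [1, 1, 0, 1, 0, 0, 0]

# The server side reconstructs the ragged tree: [[1], [1, 0], [1, 0, 0, 0]].
tree, global_state, _ = query.get_noised_result(flat_tree, global_state)
print(flat_tree.numpy(), tree)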