diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py index 58d4094..1f1c234 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py @@ -13,10 +13,14 @@ # limitations under the License. """Tree aggregation algorithm. -This algorithm computes cumulative sums of noise based on tree aggregation. When -using an appropriate noise function (e.g., Gaussian noise), it allows for -efficient differentially private algorithms under continual observation, without -prior subsampling or shuffling assumptions. +`TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise +based on tree aggregation. When using an appropriate noise function (e.g., +Gaussian noise), it allows for efficient differentially private algorithms under +continual observation, without prior subsampling or shuffling assumptions. + +`build_tree` constructs a tree given the leaf nodes by recursively summing the +children nodes to get the parent node. It allows for efficient range queries and +other statistics such as quantiles on the leaf nodes. """ import abc @@ -440,3 +444,79 @@ class EfficientTreeAggregator(): level_buffer_idx=new_level_buffer_idx, value_generator_state=value_generator_state) return cumsum, new_state + + +@tf.function +def build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: + """A function constructs a complete tree given all the leaf nodes. + + The function takes a 1-D array representing the leaf nodes of a tree and the + tree's arity, and constructs a complete tree by recursively summing the + adjacent children to get the parent until reaching the root node. Because we + assume a complete tree, if the number of leaf nodes does not divide arity, the + leaf nodes will be padded with zeros. + + Args: + leaf_nodes: A 1-D array storing the leaf nodes of the tree. + arity: A `int` for the branching factor of the tree, i.e. the number of + children for each internal node. + + Returns: + `tf.RaggedTensor` representing the tree. For example, if + `leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value + should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way, + `tree[layer][index]` can be used to access the node indexed by (layer, + index) in the tree, + + Raises: + ValueError: if parameters don't meet expectations. There are two situations + where the error is raised: (1) the input tensor has length smaller than 1; + (2) The arity is less than 2. + """ + + if len(leaf_nodes) <= 0: + raise ValueError( + 'The number of leaf nodes should at least be 1.' + f'However, an array of length {len(leaf_nodes)} is detected') + + if arity <= 1: + raise ValueError('The branching factor should be at least 2.' + f'However, a branching factor of {arity} is detected.') + + def pad_zero(leaf_nodes, size): + paddings = [[0, size - len(leaf_nodes)]] + return tf.pad(leaf_nodes, paddings) + + leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32) + num_layers = tf.math.ceil( + tf.math.log(leaf_nodes_size) / + tf.math.log(tf.constant(arity, dtype=tf.float32))) + 1 + leaf_nodes = pad_zero(leaf_nodes, tf.math.pow(float(arity), num_layers - 1)) + + def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor: + return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1) + + # The following `tf.while_loop` constructs the tree from bottom up by + # iteratively applying `_shrink_layer` to each layer of the tree. The reason + # for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not + # support auto-translation from python loop to tf loop when loop variables + # contain a `RaggedTensor` whose shape changes across iterations. + + idx = tf.identity(num_layers) + loop_cond = lambda i, h: tf.less_equal(2.0, i) + + def _loop_body(i, h): + return [ + tf.add(i, -1.0), + tf.concat(([_shrink_layer(h[0], arity)], h), axis=0) + ] + + _, tree = tf.while_loop( + loop_cond, + _loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])], + shape_invariants=[ + idx.get_shape(), + tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1) + ]) + + return tree diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py index 9a8be35..9a237ad 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_test.py @@ -365,5 +365,39 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase): self.assertAllEqual(gstate, gstate2) +class BuildTreeTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + leaf_nodes_size=[1, 2, 3, 4, 5], + arity=[2, 3], + dtype=[tf.int32, tf.float32], + ) + def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype): + """Test whether `build_tree_from_leaf` will output the correct tree.""" + + leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype) + depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1 + + tree = tree_aggregation.build_tree_from_leaf(leaf_nodes, arity) + + self.assertEqual(depth, tree.shape[0]) + + for layer in range(depth): + reverse_depth = tree.shape[0] - layer - 1 + span_size = arity**reverse_depth + for idx in range(arity**layer): + left = idx * span_size + right = (idx + 1) * span_size + expected_value = sum(leaf_nodes[left:right]) + self.assertEqual(tree[layer][idx], expected_value) + + @parameterized.named_parameters(('negative_arity', [1], -1), + ('empty_hist', [], 2)) + def test_value_error_raises(self, leaf_nodes, arity): + """Test whether `build_tree_from_leaf` will raise the correct error when the input is illegal.""" + with self.assertRaises(ValueError): + tree_aggregation.build_tree_from_leaf(leaf_nodes, arity) + + if __name__ == '__main__': tf.test.main()