Add build_tree function which takes in a histogram and builds a tree on top of it. The function will be used in CentralTreeSumQuery and DistributedTreeSumQuery in a following CL.

For more details about `CentralTreeSumQuery` and `DistributedTreeSumQuery`, please refer to the implementation design section in the following design doc: https://docs.google.com/document/d/14LL94yZx3MdorCEOE0QZNhyIx7P_3voyrl4Nlt2HF7k/edit?resourcekey=0-X3xeTk6w-fkYFezl5fxmCQ#

PiperOrigin-RevId: 382199971
This commit is contained in:
A. Unique TensorFlower 2021-06-29 17:31:02 -07:00
parent 34249f464b
commit 2396098b94
2 changed files with 118 additions and 4 deletions

View file

@ -13,10 +13,14 @@
# limitations under the License.
"""Tree aggregation algorithm.
This algorithm computes cumulative sums of noise based on tree aggregation. When
using an appropriate noise function (e.g., Gaussian noise), it allows for
efficient differentially private algorithms under continual observation, without
prior subsampling or shuffling assumptions.
`TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise
based on tree aggregation. When using an appropriate noise function (e.g.,
Gaussian noise), it allows for efficient differentially private algorithms under
continual observation, without prior subsampling or shuffling assumptions.
`build_tree` constructs a tree given the leaf nodes by recursively summing the
children nodes to get the parent node. It allows for efficient range queries and
other statistics such as quantiles on the leaf nodes.
"""
import abc
@ -440,3 +444,79 @@ class EfficientTreeAggregator():
level_buffer_idx=new_level_buffer_idx,
value_generator_state=value_generator_state)
return cumsum, new_state
@tf.function
def build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes.
The function takes a 1-D array representing the leaf nodes of a tree and the
tree's arity, and constructs a complete tree by recursively summing the
adjacent children to get the parent until reaching the root node. Because we
assume a complete tree, if the number of leaf nodes does not divide arity, the
leaf nodes will be padded with zeros.
Args:
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
arity: A `int` for the branching factor of the tree, i.e. the number of
children for each internal node.
Returns:
`tf.RaggedTensor` representing the tree. For example, if
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
`tree[layer][index]` can be used to access the node indexed by (layer,
index) in the tree,
Raises:
ValueError: if parameters don't meet expectations. There are two situations
where the error is raised: (1) the input tensor has length smaller than 1;
(2) The arity is less than 2.
"""
if len(leaf_nodes) <= 0:
raise ValueError(
'The number of leaf nodes should at least be 1.'
f'However, an array of length {len(leaf_nodes)} is detected')
if arity <= 1:
raise ValueError('The branching factor should be at least 2.'
f'However, a branching factor of {arity} is detected.')
def pad_zero(leaf_nodes, size):
paddings = [[0, size - len(leaf_nodes)]]
return tf.pad(leaf_nodes, paddings)
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
num_layers = tf.math.ceil(
tf.math.log(leaf_nodes_size) /
tf.math.log(tf.constant(arity, dtype=tf.float32))) + 1
leaf_nodes = pad_zero(leaf_nodes, tf.math.pow(float(arity), num_layers - 1))
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
# The following `tf.while_loop` constructs the tree from bottom up by
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
# support auto-translation from python loop to tf loop when loop variables
# contain a `RaggedTensor` whose shape changes across iterations.
idx = tf.identity(num_layers)
loop_cond = lambda i, h: tf.less_equal(2.0, i)
def _loop_body(i, h):
return [
tf.add(i, -1.0),
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
]
_, tree = tf.while_loop(
loop_cond,
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
shape_invariants=[
idx.get_shape(),
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
])
return tree

View file

@ -365,5 +365,39 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
self.assertAllEqual(gstate, gstate2)
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
leaf_nodes_size=[1, 2, 3, 4, 5],
arity=[2, 3],
dtype=[tf.int32, tf.float32],
)
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
"""Test whether `build_tree_from_leaf` will output the correct tree."""
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
tree = tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
self.assertEqual(depth, tree.shape[0])
for layer in range(depth):
reverse_depth = tree.shape[0] - layer - 1
span_size = arity**reverse_depth
for idx in range(arity**layer):
left = idx * span_size
right = (idx + 1) * span_size
expected_value = sum(leaf_nodes[left:right])
self.assertEqual(tree[layer][idx], expected_value)
@parameterized.named_parameters(('negative_arity', [1], -1),
('empty_hist', [], 2))
def test_value_error_raises(self, leaf_nodes, arity):
"""Test whether `build_tree_from_leaf` will raise the correct error when the input is illegal."""
with self.assertRaises(ValueError):
tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
if __name__ == '__main__':
tf.test.main()