forked from 626_privacy/tensorflow_privacy
Add `build_tree` function, which takes in a histogram and builds a tree on top of it. The function will be used in `CentralTreeSumQuery` and `DistributedTreeSumQuery` in a following CL.
For more details about `CentralTreeSumQuery` and `DistributedTreeSumQuery`, please refer to the implementation design section in the following design doc: https://docs.google.com/document/d/14LL94yZx3MdorCEOE0QZNhyIx7P_3voyrl4Nlt2HF7k/edit?resourcekey=0-X3xeTk6w-fkYFezl5fxmCQ#

PiperOrigin-RevId: 382199971
This commit is contained in:
parent
34249f464b
commit
2396098b94
2 changed files with 118 additions and 4 deletions
|
@ -13,10 +13,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tree aggregation algorithm.
|
"""Tree aggregation algorithm.
|
||||||
|
|
||||||
This algorithm computes cumulative sums of noise based on tree aggregation. When
|
`TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise
|
||||||
using an appropriate noise function (e.g., Gaussian noise), it allows for
|
based on tree aggregation. When using an appropriate noise function (e.g.,
|
||||||
efficient differentially private algorithms under continual observation, without
|
Gaussian noise), it allows for efficient differentially private algorithms under
|
||||||
prior subsampling or shuffling assumptions.
|
continual observation, without prior subsampling or shuffling assumptions.
|
||||||
|
|
||||||
|
`build_tree_from_leaf` constructs a tree given the leaf nodes by recursively summing the
|
||||||
|
children nodes to get the parent node. It allows for efficient range queries and
|
||||||
|
other statistics such as quantiles on the leaf nodes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
@ -440,3 +444,79 @@ class EfficientTreeAggregator():
|
||||||
level_buffer_idx=new_level_buffer_idx,
|
level_buffer_idx=new_level_buffer_idx,
|
||||||
value_generator_state=value_generator_state)
|
value_generator_state=value_generator_state)
|
||||||
return cumsum, new_state
|
return cumsum, new_state
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
def build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
  """Constructs a complete tree given all the leaf nodes.

  The function takes a 1-D array representing the leaf nodes of a tree and the
  tree's arity, and constructs a complete tree by repeatedly summing each group
  of `arity` adjacent children to get their parent, until the root node is
  reached. Because a complete tree is assumed, if the number of leaf nodes is
  not a power of `arity`, the leaf nodes are padded on the right with zeros up
  to the next power of `arity`.

  Args:
    leaf_nodes: A 1-D array storing the leaf nodes of the tree. Its length is
      read with `len`, so it is expected to be known when the function is
      traced.
    arity: A `int` for the branching factor of the tree, i.e. the number of
      children for each internal node.

  Returns:
    `tf.RaggedTensor` representing the tree. For example, if
    `leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
    should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
    `tree[layer][index]` can be used to access the node indexed by (layer,
    index) in the tree.

  Raises:
    ValueError: if parameters don't meet expectations. There are two situations
      where the error is raised: (1) the input tensor has length smaller than 1;
      (2) the arity is less than 2.
  """
  num_leaves = len(leaf_nodes)
  if num_leaves <= 0:
    raise ValueError(
        'The number of leaf nodes should at least be 1. '
        f'However, an array of length {num_leaves} is detected.')

  if arity <= 1:
    raise ValueError('The branching factor should be at least 2. '
                     f'However, a branching factor of {arity} is detected.')

  # Find the smallest power of `arity` that holds every leaf, and the matching
  # number of layers, with exact integer arithmetic. A floating-point
  # `ceil(log(n) / log(arity))` can round just above an integer when `n` is an
  # exact power of `arity`, which would add a spurious all-padding layer.
  num_layers = 1
  padded_size = 1
  while padded_size < num_leaves:
    padded_size *= arity
    num_layers += 1

  # Right-pad with zeros so that every internal node has exactly `arity`
  # children.
  leaf_nodes = tf.pad(leaf_nodes, [[0, padded_size - num_leaves]])

  def _shrink_layer(layer: tf.Tensor) -> tf.Tensor:
    # Sums each run of `arity` adjacent nodes to produce the parent layer.
    return tf.reduce_sum(tf.reshape(layer, (-1, arity)), 1)

  # The following `tf.while_loop` constructs the tree from bottom up by
  # iteratively prepending the `_shrink_layer` of the current top layer. The
  # loop variable is a `RaggedTensor` whose shape changes across iterations, so
  # an explicit `tf.while_loop` with `shape_invariants` is used.
  idx = tf.constant(num_layers, dtype=tf.int32)
  loop_cond = lambda i, h: tf.less_equal(2, i)

  def _loop_body(i, h):
    return [
        tf.add(i, -1),
        tf.concat(([_shrink_layer(h[0])], h), axis=0)
    ]

  _, tree = tf.while_loop(
      loop_cond,
      _loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
      shape_invariants=[
          idx.get_shape(),
          tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
      ])

  return tree
|
||||||
|
|
|
@ -365,5 +365,39 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
||||||
self.assertAllEqual(gstate, gstate2)
|
self.assertAllEqual(gstate, gstate2)
|
||||||
|
|
||||||
|
|
||||||
|
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `tree_aggregation.build_tree_from_leaf`."""

  @parameterized.product(
      leaf_nodes_size=[1, 2, 3, 4, 5],
      arity=[2, 3],
      dtype=[tf.int32, tf.float32],
  )
  def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
    """Checks that every node equals the sum of the leaves it spans."""
    leaves = tf.cast(tf.range(leaf_nodes_size), dtype)
    expected_depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1

    tree = tree_aggregation.build_tree_from_leaf(leaves, arity)

    num_layers = tree.shape[0]
    self.assertEqual(expected_depth, num_layers)

    for level in range(expected_depth):
      # Number of leaves covered by a single node at this level.
      leaves_per_node = arity**(num_layers - level - 1)
      for node in range(arity**level):
        start = node * leaves_per_node
        stop = start + leaves_per_node
        # Slicing past the last real leaf yields fewer (or no) elements, which
        # matches the zero padding applied when the tree is built.
        self.assertEqual(tree[level][node], sum(leaves[start:stop]))

  @parameterized.named_parameters(('negative_arity', [1], -1),
                                  ('empty_hist', [], 2))
  def test_value_error_raises(self, leaf_nodes, arity):
    """Checks that illegal inputs make `build_tree_from_leaf` raise."""
    with self.assertRaises(ValueError):
      tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue