From 290ecf7797a883e6015902f77f0ac1366edb57ea Mon Sep 17 00:00:00 2001 From: Wennan Zhu Date: Tue, 23 Nov 2021 10:37:41 -0800 Subject: [PATCH] Create a hierarchical histogram IterativeProcess that is compatible with tff.backends.mapreduce.MapReduceForm. PiperOrigin-RevId: 411845363 --- tensorflow_privacy/privacy/dp_query/tree_range_query.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_range_query.py b/tensorflow_privacy/privacy/dp_query/tree_range_query.py index f5a6083..471915b 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_range_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_range_query.py @@ -26,7 +26,6 @@ from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import gaussian_query -@tf.function def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: """A function constructs a complete tree given all the leaf nodes. @@ -50,10 +49,11 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: """ def pad_zero(leaf_nodes, size): - paddings = [[0, size - len(leaf_nodes)]] - return tf.pad(leaf_nodes, paddings) + paddings = tf.zeros( + shape=(size - leaf_nodes.shape[0],), dtype=leaf_nodes.dtype) + return tf.concat((leaf_nodes, paddings), axis=0) - leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32) + leaf_nodes_size = tf.constant(leaf_nodes.shape[0], dtype=tf.float32) num_layers = tf.math.ceil( tf.math.log(leaf_nodes_size) / tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1