diff --git a/tensorflow_privacy/privacy/dp_query/tree_range_query.py b/tensorflow_privacy/privacy/dp_query/tree_range_query.py index f5a6083..471915b 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_range_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_range_query.py @@ -26,7 +26,6 @@ from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import gaussian_query -@tf.function def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: """A function constructs a complete tree given all the leaf nodes. @@ -50,10 +49,11 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: """ def pad_zero(leaf_nodes, size): - paddings = [[0, size - len(leaf_nodes)]] - return tf.pad(leaf_nodes, paddings) + paddings = tf.zeros( + shape=(size - leaf_nodes.shape[0],), dtype=leaf_nodes.dtype) + return tf.concat((leaf_nodes, paddings), axis=0) - leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32) + leaf_nodes_size = tf.constant(leaf_nodes.shape[0], dtype=tf.float32) num_layers = tf.math.ceil( tf.math.log(leaf_nodes_size) / tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1