Create a hierarchical histogram IterativeProcess that is compatible with tff.backends.mapreduce.MapReduceForm.

PiperOrigin-RevId: 411845363
This commit is contained in:
Wennan Zhu 2021-11-23 10:37:41 -08:00 committed by A. Unique TensorFlower
parent 7c4f5bab09
commit 290ecf7797

View file

@ -26,7 +26,6 @@ from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
@tf.function
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes. """A function constructs a complete tree given all the leaf nodes.
@ -50,10 +49,11 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
""" """
def pad_zero(leaf_nodes, size): def pad_zero(leaf_nodes, size):
paddings = [[0, size - len(leaf_nodes)]] paddings = tf.zeros(
return tf.pad(leaf_nodes, paddings) shape=(size - leaf_nodes.shape[0],), dtype=leaf_nodes.dtype)
return tf.concat((leaf_nodes, paddings), axis=0)
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32) leaf_nodes_size = tf.constant(leaf_nodes.shape[0], dtype=tf.float32)
num_layers = tf.math.ceil( num_layers = tf.math.ceil(
tf.math.log(leaf_nodes_size) / tf.math.log(leaf_nodes_size) /
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1 tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1