Create a hierarchical histogram IterativeProcess that is compatible with tff.backends.mapreduce.MapReduceForm.
PiperOrigin-RevId: 411845363
This commit is contained in:
parent
7c4f5bab09
commit
290ecf7797
1 changed files with 4 additions and 4 deletions
|
@ -26,7 +26,6 @@ from tensorflow_privacy.privacy.dp_query import dp_query
|
|||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
|
||||
|
||||
@tf.function
|
||||
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||
"""A function constructs a complete tree given all the leaf nodes.
|
||||
|
||||
|
@ -50,10 +49,11 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
|||
"""
|
||||
|
||||
def pad_zero(leaf_nodes, size):
|
||||
paddings = [[0, size - len(leaf_nodes)]]
|
||||
return tf.pad(leaf_nodes, paddings)
|
||||
paddings = tf.zeros(
|
||||
shape=(size - leaf_nodes.shape[0],), dtype=leaf_nodes.dtype)
|
||||
return tf.concat((leaf_nodes, paddings), axis=0)
|
||||
|
||||
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
|
||||
leaf_nodes_size = tf.constant(leaf_nodes.shape[0], dtype=tf.float32)
|
||||
num_layers = tf.math.ceil(
|
||||
tf.math.log(leaf_nodes_size) /
|
||||
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
|
||||
|
|
Loading…
Reference in a new issue