forked from 626_privacy/tensorflow_privacy
Create a hierarchical histogram IterativeProcess that is compatible with tff.backends.mapreduce.MapReduceForm.
PiperOrigin-RevId: 411845363
This commit is contained in:
parent
7c4f5bab09
commit
290ecf7797
1 changed files with 4 additions and 4 deletions
|
@ -26,7 +26,6 @@ from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
|
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||||
"""A function constructs a complete tree given all the leaf nodes.
|
"""A function constructs a complete tree given all the leaf nodes.
|
||||||
|
|
||||||
|
@ -50,10 +49,11 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def pad_zero(leaf_nodes, size):
|
def pad_zero(leaf_nodes, size):
|
||||||
paddings = [[0, size - len(leaf_nodes)]]
|
paddings = tf.zeros(
|
||||||
return tf.pad(leaf_nodes, paddings)
|
shape=(size - leaf_nodes.shape[0],), dtype=leaf_nodes.dtype)
|
||||||
|
return tf.concat((leaf_nodes, paddings), axis=0)
|
||||||
|
|
||||||
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
|
leaf_nodes_size = tf.constant(leaf_nodes.shape[0], dtype=tf.float32)
|
||||||
num_layers = tf.math.ceil(
|
num_layers = tf.math.ceil(
|
||||||
tf.math.log(leaf_nodes_size) /
|
tf.math.log(leaf_nodes_size) /
|
||||||
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
|
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
|
||||||
|
|
Loading…
Reference in a new issue