(1) add CentralTreeSumQuery
and DistributedTreeSumQuery
to tree_aggregation_query.py. (2) move build_tree_from_leaf
to tree_aggregation_query.py together with CentralTreeSumQuery
.
PiperOrigin-RevId: 383511025
This commit is contained in:
parent
d6aa796684
commit
caf6f36b80
4 changed files with 594 additions and 122 deletions
|
@ -17,10 +17,6 @@
|
||||||
based on tree aggregation. When using an appropriate noise function (e.g.,
|
based on tree aggregation. When using an appropriate noise function (e.g.,
|
||||||
Gaussian noise), it allows for efficient differentially private algorithms under
|
Gaussian noise), it allows for efficient differentially private algorithms under
|
||||||
continual observation, without prior subsampling or shuffling assumptions.
|
continual observation, without prior subsampling or shuffling assumptions.
|
||||||
|
|
||||||
`build_tree` constructs a tree given the leaf nodes by recursively summing the
|
|
||||||
children nodes to get the parent node. It allows for efficient range queries and
|
|
||||||
other statistics such as quantiles on the leaf nodes.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
@ -449,79 +445,3 @@ class EfficientTreeAggregator():
|
||||||
level_buffer_idx=new_level_buffer_idx,
|
level_buffer_idx=new_level_buffer_idx,
|
||||||
value_generator_state=value_generator_state)
|
value_generator_state=value_generator_state)
|
||||||
return cumsum, new_state
|
return cumsum, new_state
|
||||||
|
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
|
||||||
"""A function constructs a complete tree given all the leaf nodes.
|
|
||||||
|
|
||||||
The function takes a 1-D array representing the leaf nodes of a tree and the
|
|
||||||
tree's arity, and constructs a complete tree by recursively summing the
|
|
||||||
adjacent children to get the parent until reaching the root node. Because we
|
|
||||||
assume a complete tree, if the number of leaf nodes does not divide arity, the
|
|
||||||
leaf nodes will be padded with zeros.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
|
|
||||||
arity: A `int` for the branching factor of the tree, i.e. the number of
|
|
||||||
children for each internal node.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`tf.RaggedTensor` representing the tree. For example, if
|
|
||||||
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
|
|
||||||
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
|
|
||||||
`tree[layer][index]` can be used to access the node indexed by (layer,
|
|
||||||
index) in the tree,
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if parameters don't meet expectations. There are two situations
|
|
||||||
where the error is raised: (1) the input tensor has length smaller than 1;
|
|
||||||
(2) The arity is less than 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if len(leaf_nodes) <= 0:
|
|
||||||
raise ValueError(
|
|
||||||
'The number of leaf nodes should at least be 1.'
|
|
||||||
f'However, an array of length {len(leaf_nodes)} is detected')
|
|
||||||
|
|
||||||
if arity <= 1:
|
|
||||||
raise ValueError('The branching factor should be at least 2.'
|
|
||||||
f'However, a branching factor of {arity} is detected.')
|
|
||||||
|
|
||||||
def pad_zero(leaf_nodes, size):
|
|
||||||
paddings = [[0, size - len(leaf_nodes)]]
|
|
||||||
return tf.pad(leaf_nodes, paddings)
|
|
||||||
|
|
||||||
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
|
|
||||||
num_layers = tf.math.ceil(
|
|
||||||
tf.math.log(leaf_nodes_size) /
|
|
||||||
tf.math.log(tf.constant(arity, dtype=tf.float32))) + 1
|
|
||||||
leaf_nodes = pad_zero(leaf_nodes, tf.math.pow(float(arity), num_layers - 1))
|
|
||||||
|
|
||||||
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
|
|
||||||
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
|
|
||||||
|
|
||||||
# The following `tf.while_loop` constructs the tree from bottom up by
|
|
||||||
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
|
|
||||||
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
|
|
||||||
# support auto-translation from python loop to tf loop when loop variables
|
|
||||||
# contain a `RaggedTensor` whose shape changes across iterations.
|
|
||||||
|
|
||||||
idx = tf.identity(num_layers)
|
|
||||||
loop_cond = lambda i, h: tf.less_equal(2.0, i)
|
|
||||||
|
|
||||||
def _loop_body(i, h):
|
|
||||||
return [
|
|
||||||
tf.add(i, -1.0),
|
|
||||||
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
|
|
||||||
]
|
|
||||||
|
|
||||||
_, tree = tf.while_loop(
|
|
||||||
loop_cond,
|
|
||||||
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
|
|
||||||
shape_invariants=[
|
|
||||||
idx.get_shape(),
|
|
||||||
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
|
|
||||||
])
|
|
||||||
|
|
||||||
return tree
|
|
||||||
|
|
|
@ -11,9 +11,22 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""DPQuery for continual observation queries relying on `tree_aggregation`."""
|
"""`DPQuery`s for differentially private tree aggregation protocols.
|
||||||
|
|
||||||
|
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
|
||||||
|
online observation queries relying on `tree_aggregation`. 'Online' means that
|
||||||
|
the leaf nodes of the tree arrive one by one as the time proceeds. The leaves
|
||||||
|
are vector records as defined in `dp_query.DPQuery`.
|
||||||
|
|
||||||
|
`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for
|
||||||
|
central/distributed offline tree aggregation protocol. 'Offline' means all the
|
||||||
|
leaf nodes are ready before the protocol starts. Each record, different from
|
||||||
|
what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
|
||||||
|
"""
|
||||||
|
import distutils
|
||||||
|
import math
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
@ -34,8 +47,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
corresponding record, e.g. clip_fn(flat_record, clip_value).
|
corresponding record, e.g. clip_fn(flat_record, clip_value).
|
||||||
clip_value: float indicating the value at which to clip the record.
|
clip_value: float indicating the value at which to clip the record.
|
||||||
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
|
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
|
||||||
tree_aggregator: `tree_aggregation.TreeAggregator` initialized with
|
tree_aggregator: `tree_aggregation.TreeAggregator` initialized with user
|
||||||
user defined `noise_generator`. `noise_generator` is a
|
defined `noise_generator`. `noise_generator` is a
|
||||||
`tree_aggregation.ValueGenerator` to generate the noise value for a tree
|
`tree_aggregation.ValueGenerator` to generate the noise value for a tree
|
||||||
node. Noise stdandard deviation is specified outside the `dp_query` by the
|
node. Noise stdandard deviation is specified outside the `dp_query` by the
|
||||||
user when defining `noise_fn` and should have order
|
user when defining `noise_fn` and should have order
|
||||||
|
@ -364,3 +377,297 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
record_specs=record_specs,
|
record_specs=record_specs,
|
||||||
noise_generator=gaussian_noise_generator,
|
noise_generator=gaussian_noise_generator,
|
||||||
use_efficient=use_efficient)
|
use_efficient=use_efficient)
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||||
|
"""A function constructs a complete tree given all the leaf nodes.
|
||||||
|
|
||||||
|
The function takes a 1-D array representing the leaf nodes of a tree and the
|
||||||
|
tree's arity, and constructs a complete tree by recursively summing the
|
||||||
|
adjacent children to get the parent until reaching the root node. Because we
|
||||||
|
assume a complete tree, if the number of leaf nodes does not divide arity, the
|
||||||
|
leaf nodes will be padded with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
|
||||||
|
arity: A `int` for the branching factor of the tree, i.e. the number of
|
||||||
|
children for each internal node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tf.RaggedTensor` representing the tree. For example, if
|
||||||
|
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
|
||||||
|
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
|
||||||
|
`tree[layer][index]` can be used to access the node indexed by (layer,
|
||||||
|
index) in the tree,
|
||||||
|
"""
|
||||||
|
|
||||||
|
def pad_zero(leaf_nodes, size):
|
||||||
|
paddings = [[0, size - len(leaf_nodes)]]
|
||||||
|
return tf.pad(leaf_nodes, paddings)
|
||||||
|
|
||||||
|
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
|
||||||
|
num_layers = tf.math.ceil(
|
||||||
|
tf.math.log(leaf_nodes_size) /
|
||||||
|
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
|
||||||
|
leaf_nodes = pad_zero(
|
||||||
|
leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1))
|
||||||
|
|
||||||
|
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
|
||||||
|
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
|
||||||
|
|
||||||
|
# The following `tf.while_loop` constructs the tree from bottom up by
|
||||||
|
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
|
||||||
|
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
|
||||||
|
# support auto-translation from python loop to tf loop when loop variables
|
||||||
|
# contain a `RaggedTensor` whose shape changes across iterations.
|
||||||
|
|
||||||
|
idx = tf.identity(num_layers)
|
||||||
|
loop_cond = lambda i, h: tf.less_equal(2.0, i)
|
||||||
|
|
||||||
|
def _loop_body(i, h):
|
||||||
|
return [
|
||||||
|
tf.add(i, -1.0),
|
||||||
|
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
|
||||||
|
]
|
||||||
|
|
||||||
|
_, tree = tf.while_loop(
|
||||||
|
loop_cond,
|
||||||
|
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
|
||||||
|
shape_invariants=[
|
||||||
|
idx.get_shape(),
|
||||||
|
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
|
||||||
|
])
|
||||||
|
|
||||||
|
return tree
|
||||||
|
|
||||||
|
|
||||||
|
def _get_add_noise(stddev):
|
||||||
|
"""Utility function to decide which `add_noise` to use according to tf version."""
|
||||||
|
if distutils.version.LooseVersion(
|
||||||
|
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
||||||
|
|
||||||
|
def add_noise(v):
|
||||||
|
return v + tf.random.normal(
|
||||||
|
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
|
||||||
|
else:
|
||||||
|
random_normal = tf.random_normal_initializer(stddev=stddev)
|
||||||
|
|
||||||
|
def add_noise(v):
|
||||||
|
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||||
|
|
||||||
|
return add_noise
|
||||||
|
|
||||||
|
|
||||||
|
class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
"""Implements dp_query for differentially private tree aggregation protocol.
|
||||||
|
|
||||||
|
Implements a central variant of the tree aggregation protocol from the paper
|
||||||
|
"'Is interaction necessary for distributed private learning?.' Adam Smith,
|
||||||
|
Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with
|
||||||
|
gaussian mechanism. The first step is to clip the clients' local updates (i.e.
|
||||||
|
a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it
|
||||||
|
does not exceed a prespecified upper bound. The second step is to construct
|
||||||
|
the tree on the clipped update. The third step is to add independent gaussian
|
||||||
|
noise to each node in the tree. The returned tree can support efficient and
|
||||||
|
accurate range queries with differential privacy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@attr.s(frozen=True)
|
||||||
|
class GlobalState(object):
|
||||||
|
"""Class defining global state for `CentralTreeSumQuery`.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
stddev: The stddev of the noise added to each node in the tree.
|
||||||
|
arity: The branching factor of the tree (i.e. the number of children each
|
||||||
|
internal node has).
|
||||||
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
"""
|
||||||
|
stddev = attr.ib()
|
||||||
|
arity = attr.ib()
|
||||||
|
l1_bound = attr.ib()
|
||||||
|
|
||||||
|
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
|
||||||
|
"""Initializes the `CentralTreeSumQuery`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stddev: The stddev of the noise added to each internal node of the
|
||||||
|
constructed tree.
|
||||||
|
arity: The branching factor of the tree.
|
||||||
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
"""
|
||||||
|
self._stddev = stddev
|
||||||
|
self._arity = arity
|
||||||
|
self._l1_bound = l1_bound
|
||||||
|
|
||||||
|
def initial_global_state(self):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
|
return CentralTreeSumQuery.GlobalState(
|
||||||
|
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
|
||||||
|
|
||||||
|
def derive_sample_params(self, global_state):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
|
return global_state.l1_bound
|
||||||
|
|
||||||
|
def preprocess_record(self, params, record):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||||
|
casted_record = tf.cast(record, tf.float32)
|
||||||
|
l1_norm = tf.norm(casted_record, ord=1)
|
||||||
|
|
||||||
|
l1_bound = tf.cast(params, tf.float32)
|
||||||
|
|
||||||
|
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
|
||||||
|
l1_bound,
|
||||||
|
use_norm=l1_norm)
|
||||||
|
|
||||||
|
return preprocessed_record[0]
|
||||||
|
|
||||||
|
def get_noised_result(self, sample_state, global_state):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_state: a frequency histogram.
|
||||||
|
global_state: hyper-parameters of the query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a `tf.RaggedTensor` representing the tree built on top of `sample_state`.
|
||||||
|
The jth node on the ith layer of the tree can be accessed by tree[i][j]
|
||||||
|
where tree is the returned value.
|
||||||
|
"""
|
||||||
|
add_noise = _get_add_noise(self._stddev)
|
||||||
|
tree = _build_tree_from_leaf(sample_state, global_state.arity)
|
||||||
|
return tf.nest.map_structure(
|
||||||
|
add_noise, tree, expand_composites=True), global_state
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
"""Implements dp_query for differentially private tree aggregation protocol.
|
||||||
|
|
||||||
|
The difference from `CentralTreeSumQuery` is that the tree construction and
|
||||||
|
gaussian noise addition happen in `preprocess_records`. The difference only
|
||||||
|
takes effect when used together with
|
||||||
|
`tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class
|
||||||
|
should be treated as equal with `CentralTreeSumQuery`.
|
||||||
|
|
||||||
|
Implements a distributed version of the tree aggregation protocol from. "Is
|
||||||
|
interaction necessary for distributed private learning?." by replacing their
|
||||||
|
local randomizer with gaussian mechanism. The first step is to check the L1
|
||||||
|
norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes
|
||||||
|
of the tree) to make sure it does not exceed a prespecified upper bound. The
|
||||||
|
second step is to construct the tree. The third step is to add independent
|
||||||
|
gaussian noise to each node in the tree. The returned tree can support
|
||||||
|
efficient and accurate range queries with differential privacy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@attr.s(frozen=True)
|
||||||
|
class GlobalState(object):
|
||||||
|
"""Class defining global state for DistributedTreeSumQuery.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
stddev: The stddev of the noise added to each internal node in the
|
||||||
|
constructed tree.
|
||||||
|
arity: The branching factor of the tree (i.e. the number of children each
|
||||||
|
internal node has).
|
||||||
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
"""
|
||||||
|
stddev = attr.ib()
|
||||||
|
arity = attr.ib()
|
||||||
|
l1_bound = attr.ib()
|
||||||
|
|
||||||
|
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
|
||||||
|
"""Initializes the `DistributedTreeSumQuery`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stddev: The stddev of the noise added to each node in the tree.
|
||||||
|
arity: The branching factor of the tree.
|
||||||
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
"""
|
||||||
|
self._stddev = stddev
|
||||||
|
self._arity = arity
|
||||||
|
self._l1_bound = l1_bound
|
||||||
|
|
||||||
|
def initial_global_state(self):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
|
return DistributedTreeSumQuery.GlobalState(
|
||||||
|
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
|
||||||
|
|
||||||
|
def derive_sample_params(self, global_state):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
|
return (global_state.stddev, global_state.arity, global_state.l1_bound)
|
||||||
|
|
||||||
|
def preprocess_record(self, params, record):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||||
|
|
||||||
|
This method clips the input record by L1 norm, constructs a tree on top of
|
||||||
|
it, and adds gaussian noise to each node of the tree for differential
|
||||||
|
privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function
|
||||||
|
flattens the `tf.RaggedTensor` before outputting it. This is useful when
|
||||||
|
used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does
|
||||||
|
not accept ragged output tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: hyper-parameters for preprocessing record, (stddev, aritry,
|
||||||
|
l1_bound)
|
||||||
|
record: leaf nodes for the tree.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tf.Tensor` representing the flattened version of the tree.
|
||||||
|
"""
|
||||||
|
_, arity, l1_bound_ = params
|
||||||
|
l1_bound = tf.cast(l1_bound_, tf.float32)
|
||||||
|
|
||||||
|
casted_record = tf.cast(record, tf.float32)
|
||||||
|
l1_norm = tf.norm(casted_record, ord=1)
|
||||||
|
|
||||||
|
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
|
||||||
|
l1_bound,
|
||||||
|
use_norm=l1_norm)
|
||||||
|
preprocessed_record = preprocessed_record[0]
|
||||||
|
|
||||||
|
add_noise = _get_add_noise(self._stddev)
|
||||||
|
tree = _build_tree_from_leaf(preprocessed_record, arity)
|
||||||
|
noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True)
|
||||||
|
|
||||||
|
# The following codes reshape the output vector so the output shape of can
|
||||||
|
# be statically inferred. This is useful when used with
|
||||||
|
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
||||||
|
# the output shape of this function statically and explicitly.
|
||||||
|
flat_noisy_tree = noisy_tree.flat_values
|
||||||
|
flat_tree_shape = [
|
||||||
|
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
||||||
|
1) // (self._arity - 1)
|
||||||
|
]
|
||||||
|
return tf.reshape(flat_noisy_tree, flat_tree_shape)
|
||||||
|
|
||||||
|
def get_noised_result(self, sample_state, global_state):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||||
|
|
||||||
|
This function re-constructs the `tf.RaggedTensor` from the flattened tree
|
||||||
|
output by `preprocess_records.`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_state: `tf.Tensor` for the flattened tree.
|
||||||
|
global_state: hyper-parameters including noise multiplier, the branching
|
||||||
|
factor of the tree and the maximum records per user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a `tf.RaggedTensor` for the tree.
|
||||||
|
"""
|
||||||
|
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
|
||||||
|
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
|
||||||
|
# row_splits=[0, 4, 4, 7, 8, 8]))
|
||||||
|
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
|
||||||
|
# This part is not written in tensorflow and will be executed on the server
|
||||||
|
# side instead of the client side if used with
|
||||||
|
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||||
|
row_splits = [0] + [
|
||||||
|
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
||||||
|
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
|
||||||
|
]
|
||||||
|
tree = tf.RaggedTensor.from_row_splits(
|
||||||
|
values=sample_state, row_splits=row_splits)
|
||||||
|
return tree, global_state
|
||||||
|
|
|
@ -13,16 +13,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for `tree_aggregation_query`."""
|
"""Tests for `tree_aggregation_query`."""
|
||||||
|
|
||||||
from absl.testing import parameterized
|
import math
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||||
|
|
||||||
|
|
||||||
STRUCT_RECORD = [
|
STRUCT_RECORD = [
|
||||||
tf.constant([[2.0, 0.0], [0.0, 1.0]]),
|
tf.constant([[2.0, 0.0], [0.0, 1.0]]),
|
||||||
tf.constant([-1.0, 0.0])
|
tf.constant([-1.0, 0.0])
|
||||||
|
@ -55,6 +54,7 @@ def _get_noise_fn(specs, stddev=NOISE_STD, seed=1):
|
||||||
|
|
||||||
def _get_no_noise_fn(specs):
|
def _get_no_noise_fn(specs):
|
||||||
shape = tf.nest.map_structure(lambda spec: spec.shape, specs)
|
shape = tf.nest.map_structure(lambda spec: spec.shape, specs)
|
||||||
|
|
||||||
def no_noise_fn():
|
def no_noise_fn():
|
||||||
return tf.nest.map_structure(tf.zeros, shape)
|
return tf.nest.map_structure(tf.zeros, shape)
|
||||||
|
|
||||||
|
@ -73,6 +73,7 @@ def _get_l2_clip_fn():
|
||||||
def _get_l_infty_clip_fn():
|
def _get_l_infty_clip_fn():
|
||||||
|
|
||||||
def l_infty_clip_fn(record_as_list, clip_value):
|
def l_infty_clip_fn(record_as_list, clip_value):
|
||||||
|
|
||||||
def clip(record):
|
def clip(record):
|
||||||
return tf.clip_by_value(
|
return tf.clip_by_value(
|
||||||
record, clip_value_min=-clip_value, clip_value_max=clip_value)
|
record, clip_value_min=-clip_value, clip_value_max=clip_value)
|
||||||
|
@ -395,5 +396,283 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertIsInstance(query._tree_aggregator, tree_class)
|
self.assertIsInstance(query._tree_aggregator, tree_class)
|
||||||
|
|
||||||
|
|
||||||
|
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.product(
|
||||||
|
leaf_nodes_size=[1, 2, 3, 4, 5],
|
||||||
|
arity=[2, 3],
|
||||||
|
dtype=[tf.int32, tf.float32],
|
||||||
|
)
|
||||||
|
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
|
||||||
|
"""Test whether `_build_tree_from_leaf` will output the correct tree."""
|
||||||
|
|
||||||
|
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
|
||||||
|
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
|
||||||
|
|
||||||
|
tree = tree_aggregation_query._build_tree_from_leaf(leaf_nodes, arity)
|
||||||
|
|
||||||
|
self.assertEqual(depth, tree.shape[0])
|
||||||
|
|
||||||
|
for layer in range(depth):
|
||||||
|
reverse_depth = tree.shape[0] - layer - 1
|
||||||
|
span_size = arity**reverse_depth
|
||||||
|
for idx in range(arity**layer):
|
||||||
|
left = idx * span_size
|
||||||
|
right = (idx + 1) * span_size
|
||||||
|
expected_value = sum(leaf_nodes[left:right])
|
||||||
|
self.assertEqual(tree[layer][idx], expected_value)
|
||||||
|
|
||||||
|
|
||||||
|
class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_initial_global_state_type(self):
|
||||||
|
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
self.assertIsInstance(
|
||||||
|
global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState)
|
||||||
|
|
||||||
|
def test_derive_sample_params(self):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
self.assertAllClose(params, 10.)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
|
||||||
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.],
|
||||||
|
dtype=tf.float32)),
|
||||||
|
)
|
||||||
|
def test_preprocess_record(self, arity, record):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(
|
||||||
|
stddev=NOISE_STD, arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
|
||||||
|
self.assertAllClose(preprocessed_record, record)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
|
||||||
|
('binary_test_float', 2, tf.constant(
|
||||||
|
[10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
|
||||||
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
|
||||||
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
|
||||||
|
)
|
||||||
|
def test_preprocess_record_clipped(self, arity, record,
|
||||||
|
expected_clipped_value):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(
|
||||||
|
stddev=NOISE_STD, arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
self.assertAllClose(preprocessed_record, expected_clipped_value)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
sample_state, global_state = query.get_noised_result(
|
||||||
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
|
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
|
)
|
||||||
|
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
sample_state_list = []
|
||||||
|
for _ in range(1000):
|
||||||
|
sample_state, _ = query.get_noised_result(preprocessed_record,
|
||||||
|
global_state)
|
||||||
|
sample_state_list.append(sample_state.flat_values.numpy())
|
||||||
|
expectation = np.mean(sample_state_list, axis=0)
|
||||||
|
variance = np.std(sample_state_list, axis=0)
|
||||||
|
|
||||||
|
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
|
||||||
|
self.assertAllClose(
|
||||||
|
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result_clipped(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
sample_state, global_state = query.get_noised_result(
|
||||||
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_initial_global_state_type(self):
|
||||||
|
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
self.assertIsInstance(
|
||||||
|
global_state,
|
||||||
|
tree_aggregation_query.DistributedTreeSumQuery.GlobalState)
|
||||||
|
|
||||||
|
def test_derive_sample_params(self):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
stddev, arity, l1_bound = query.derive_sample_params(
|
||||||
|
global_state)
|
||||||
|
self.assertAllClose(stddev, NOISE_STD)
|
||||||
|
self.assertAllClose(arity, 2)
|
||||||
|
self.assertAllClose(l1_bound, 10)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
|
||||||
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
|
||||||
|
])),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
|
||||||
|
])),
|
||||||
|
)
|
||||||
|
def test_preprocess_record(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
self.assertAllClose(preprocessed_record, expected_tree)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
|
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
|
)
|
||||||
|
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
|
||||||
|
preprocessed_record_list = []
|
||||||
|
for _ in range(1000):
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
preprocessed_record_list.append(preprocessed_record.numpy())
|
||||||
|
|
||||||
|
expectation = np.mean(preprocessed_record_list, axis=0)
|
||||||
|
variance = np.std(preprocessed_record_list, axis=0)
|
||||||
|
|
||||||
|
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
|
||||||
|
self.assertAllClose(
|
||||||
|
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
|
||||||
|
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
|
||||||
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant(
|
||||||
|
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
|
||||||
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant(
|
||||||
|
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
|
||||||
|
)
|
||||||
|
def test_preprocess_record_clipped(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
self.assertAllClose(preprocessed_record, expected_tree)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
sample_state, global_state = query.get_noised_result(
|
||||||
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result_clipped(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
sample_state, global_state = query.get_noised_result(
|
||||||
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -365,39 +365,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
||||||
self.assertAllEqual(gstate, gstate2)
|
self.assertAllEqual(gstate, gstate2)
|
||||||
|
|
||||||
|
|
||||||
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
|
|
||||||
|
|
||||||
@parameterized.product(
|
|
||||||
leaf_nodes_size=[1, 2, 3, 4, 5],
|
|
||||||
arity=[2, 3],
|
|
||||||
dtype=[tf.int32, tf.float32],
|
|
||||||
)
|
|
||||||
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
|
|
||||||
"""Test whether `build_tree_from_leaf` will output the correct tree."""
|
|
||||||
|
|
||||||
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
|
|
||||||
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
|
|
||||||
|
|
||||||
tree = tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
|
|
||||||
|
|
||||||
self.assertEqual(depth, tree.shape[0])
|
|
||||||
|
|
||||||
for layer in range(depth):
|
|
||||||
reverse_depth = tree.shape[0] - layer - 1
|
|
||||||
span_size = arity**reverse_depth
|
|
||||||
for idx in range(arity**layer):
|
|
||||||
left = idx * span_size
|
|
||||||
right = (idx + 1) * span_size
|
|
||||||
expected_value = sum(leaf_nodes[left:right])
|
|
||||||
self.assertEqual(tree[layer][idx], expected_value)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(('negative_arity', [1], -1),
|
|
||||||
('empty_hist', [], 2))
|
|
||||||
def test_value_error_raises(self, leaf_nodes, arity):
|
|
||||||
"""Test whether `build_tree_from_leaf` will raise the correct error when the input is illegal."""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue