(1) add CentralTreeSumQuery and DistributedTreeSumQuery to tree_aggregation_query.py. (2) move build_tree_from_leaf to tree_aggregation_query.py together with CentralTreeSumQuery.

PiperOrigin-RevId: 383511025
This commit is contained in:
A. Unique TensorFlower 2021-07-07 15:54:53 -07:00
parent d6aa796684
commit caf6f36b80
4 changed files with 594 additions and 122 deletions

View file

@ -17,10 +17,6 @@
based on tree aggregation. When using an appropriate noise function (e.g., based on tree aggregation. When using an appropriate noise function (e.g.,
Gaussian noise), it allows for efficient differentially private algorithms under Gaussian noise), it allows for efficient differentially private algorithms under
continual observation, without prior subsampling or shuffling assumptions. continual observation, without prior subsampling or shuffling assumptions.
`build_tree` constructs a tree given the leaf nodes by recursively summing the
children nodes to get the parent node. It allows for efficient range queries and
other statistics such as quantiles on the leaf nodes.
""" """
import abc import abc
@ -449,79 +445,3 @@ class EfficientTreeAggregator():
level_buffer_idx=new_level_buffer_idx, level_buffer_idx=new_level_buffer_idx,
value_generator_state=value_generator_state) value_generator_state=value_generator_state)
return cumsum, new_state return cumsum, new_state
@tf.function
def build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes.
The function takes a 1-D array representing the leaf nodes of a tree and the
tree's arity, and constructs a complete tree by recursively summing the
adjacent children to get the parent until reaching the root node. Because we
assume a complete tree, if the number of leaf nodes does not divide arity, the
leaf nodes will be padded with zeros.
Args:
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
arity: A `int` for the branching factor of the tree, i.e. the number of
children for each internal node.
Returns:
`tf.RaggedTensor` representing the tree. For example, if
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
`tree[layer][index]` can be used to access the node indexed by (layer,
index) in the tree,
Raises:
ValueError: if parameters don't meet expectations. There are two situations
where the error is raised: (1) the input tensor has length smaller than 1;
(2) The arity is less than 2.
"""
if len(leaf_nodes) <= 0:
raise ValueError(
'The number of leaf nodes should at least be 1.'
f'However, an array of length {len(leaf_nodes)} is detected')
if arity <= 1:
raise ValueError('The branching factor should be at least 2.'
f'However, a branching factor of {arity} is detected.')
def pad_zero(leaf_nodes, size):
paddings = [[0, size - len(leaf_nodes)]]
return tf.pad(leaf_nodes, paddings)
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
num_layers = tf.math.ceil(
tf.math.log(leaf_nodes_size) /
tf.math.log(tf.constant(arity, dtype=tf.float32))) + 1
leaf_nodes = pad_zero(leaf_nodes, tf.math.pow(float(arity), num_layers - 1))
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
# The following `tf.while_loop` constructs the tree from bottom up by
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
# support auto-translation from python loop to tf loop when loop variables
# contain a `RaggedTensor` whose shape changes across iterations.
idx = tf.identity(num_layers)
loop_cond = lambda i, h: tf.less_equal(2.0, i)
def _loop_body(i, h):
return [
tf.add(i, -1.0),
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
]
_, tree = tf.while_loop(
loop_cond,
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
shape_invariants=[
idx.get_shape(),
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
])
return tree

View file

@ -11,9 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""DPQuery for continual observation queries relying on `tree_aggregation`.""" """`DPQuery`s for differentially private tree aggregation protocols.
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
online observation queries relying on `tree_aggregation`. 'Online' means that
the leaf nodes of the tree arrive one by one as the time proceeds. The leaves
are vector records as defined in `dp_query.DPQuery`.
`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for
central/distributed offline tree aggregation protocol. 'Offline' means all the
leaf nodes are ready before the protocol starts. Each record, different from
what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
"""
import distutils
import math
import attr import attr
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
@ -31,11 +44,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
Attributes: Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two clip_fn: Callable that specifies clipping function. `clip_fn` receives two
arguments: a flat list of vars in a record and a `clip_value` to clip the arguments: a flat list of vars in a record and a `clip_value` to clip the
corresponding record, e.g. clip_fn(flat_record, clip_value). corresponding record, e.g. clip_fn(flat_record, clip_value).
clip_value: float indicating the value at which to clip the record. clip_value: float indicating the value at which to clip the record.
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
tree_aggregator: `tree_aggregation.TreeAggregator` initialized with tree_aggregator: `tree_aggregation.TreeAggregator` initialized with user
user defined `noise_generator`. `noise_generator` is a defined `noise_generator`. `noise_generator` is a
`tree_aggregation.ValueGenerator` to generate the noise value for a tree `tree_aggregation.ValueGenerator` to generate the noise value for a tree
node. Noise stdandard deviation is specified outside the `dp_query` by the node. Noise stdandard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order user when defining `noise_fn` and should have order
@ -209,7 +222,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
Attributes: Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two clip_fn: Callable that specifies clipping function. `clip_fn` receives two
arguments: a flat list of vars in a record and a `clip_value` to clip the arguments: a flat list of vars in a record and a `clip_value` to clip the
corresponding record, e.g. clip_fn(flat_record, clip_value). corresponding record, e.g. clip_fn(flat_record, clip_value).
clip_value: float indicating the value at which to clip the record. clip_value: float indicating the value at which to clip the record.
record_specs: A nested structure of `tf.TensorSpec`s specifying structure record_specs: A nested structure of `tf.TensorSpec`s specifying structure
and shapes of records. and shapes of records.
@ -364,3 +377,297 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
record_specs=record_specs, record_specs=record_specs,
noise_generator=gaussian_noise_generator, noise_generator=gaussian_noise_generator,
use_efficient=use_efficient) use_efficient=use_efficient)
@tf.function
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes.
The function takes a 1-D array representing the leaf nodes of a tree and the
tree's arity, and constructs a complete tree by recursively summing the
adjacent children to get the parent until reaching the root node. Because we
assume a complete tree, if the number of leaf nodes does not divide arity, the
leaf nodes will be padded with zeros.
Args:
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
arity: A `int` for the branching factor of the tree, i.e. the number of
children for each internal node.
Returns:
`tf.RaggedTensor` representing the tree. For example, if
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
`tree[layer][index]` can be used to access the node indexed by (layer,
index) in the tree,
"""
def pad_zero(leaf_nodes, size):
paddings = [[0, size - len(leaf_nodes)]]
return tf.pad(leaf_nodes, paddings)
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
num_layers = tf.math.ceil(
tf.math.log(leaf_nodes_size) /
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
leaf_nodes = pad_zero(
leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1))
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
# The following `tf.while_loop` constructs the tree from bottom up by
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
# support auto-translation from python loop to tf loop when loop variables
# contain a `RaggedTensor` whose shape changes across iterations.
idx = tf.identity(num_layers)
loop_cond = lambda i, h: tf.less_equal(2.0, i)
def _loop_body(i, h):
return [
tf.add(i, -1.0),
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
]
_, tree = tf.while_loop(
loop_cond,
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
shape_invariants=[
idx.get_shape(),
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
])
return tree
def _get_add_noise(stddev):
"""Utility function to decide which `add_noise` to use according to tf version."""
if distutils.version.LooseVersion(
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
def add_noise(v):
return v + tf.random.normal(
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
else:
random_normal = tf.random_normal_initializer(stddev=stddev)
def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
return add_noise
class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for differentially private tree aggregation protocol.
Implements a central variant of the tree aggregation protocol from the paper
"'Is interaction necessary for distributed private learning?.' Adam Smith,
Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with
gaussian mechanism. The first step is to clip the clients' local updates (i.e.
a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it
does not exceed a prespecified upper bound. The second step is to construct
the tree on the clipped update. The third step is to add independent gaussian
noise to each node in the tree. The returned tree can support efficient and
accurate range queries with differential privacy.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for `CentralTreeSumQuery`.
Attributes:
stddev: The stddev of the noise added to each node in the tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
stddev = attr.ib()
arity = attr.ib()
l1_bound = attr.ib()
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
"""Initializes the `CentralTreeSumQuery`.
Args:
stddev: The stddev of the noise added to each internal node of the
constructed tree.
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return CentralTreeSumQuery.GlobalState(
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return global_state.l1_bound
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
casted_record = tf.cast(record, tf.float32)
l1_norm = tf.norm(casted_record, ord=1)
l1_bound = tf.cast(params, tf.float32)
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
l1_bound,
use_norm=l1_norm)
return preprocessed_record[0]
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
Args:
sample_state: a frequency histogram.
global_state: hyper-parameters of the query.
Returns:
a `tf.RaggedTensor` representing the tree built on top of `sample_state`.
The jth node on the ith layer of the tree can be accessed by tree[i][j]
where tree is the returned value.
"""
add_noise = _get_add_noise(self._stddev)
tree = _build_tree_from_leaf(sample_state, global_state.arity)
return tf.nest.map_structure(
add_noise, tree, expand_composites=True), global_state
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for differentially private tree aggregation protocol.
The difference from `CentralTreeSumQuery` is that the tree construction and
gaussian noise addition happen in `preprocess_records`. The difference only
takes effect when used together with
`tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class
should be treated as equal with `CentralTreeSumQuery`.
Implements a distributed version of the tree aggregation protocol from. "Is
interaction necessary for distributed private learning?." by replacing their
local randomizer with gaussian mechanism. The first step is to check the L1
norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes
of the tree) to make sure it does not exceed a prespecified upper bound. The
second step is to construct the tree. The third step is to add independent
gaussian noise to each node in the tree. The returned tree can support
efficient and accurate range queries with differential privacy.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for DistributedTreeSumQuery.
Attributes:
stddev: The stddev of the noise added to each internal node in the
constructed tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
stddev = attr.ib()
arity = attr.ib()
l1_bound = attr.ib()
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
"""Initializes the `DistributedTreeSumQuery`.
Args:
stddev: The stddev of the noise added to each node in the tree.
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return DistributedTreeSumQuery.GlobalState(
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.stddev, global_state.arity, global_state.l1_bound)
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
This method clips the input record by L1 norm, constructs a tree on top of
it, and adds gaussian noise to each node of the tree for differential
privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function
flattens the `tf.RaggedTensor` before outputting it. This is useful when
used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does
not accept ragged output tensor.
Args:
params: hyper-parameters for preprocessing record, (stddev, aritry,
l1_bound)
record: leaf nodes for the tree.
Returns:
`tf.Tensor` representing the flattened version of the tree.
"""
_, arity, l1_bound_ = params
l1_bound = tf.cast(l1_bound_, tf.float32)
casted_record = tf.cast(record, tf.float32)
l1_norm = tf.norm(casted_record, ord=1)
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
l1_bound,
use_norm=l1_norm)
preprocessed_record = preprocessed_record[0]
add_noise = _get_add_noise(self._stddev)
tree = _build_tree_from_leaf(preprocessed_record, arity)
noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True)
# The following codes reshape the output vector so the output shape of can
# be statically inferred. This is useful when used with
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
# the output shape of this function statically and explicitly.
flat_noisy_tree = noisy_tree.flat_values
flat_tree_shape = [
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
1) // (self._arity - 1)
]
return tf.reshape(flat_noisy_tree, flat_tree_shape)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
This function re-constructs the `tf.RaggedTensor` from the flattened tree
output by `preprocess_records.`
Args:
sample_state: `tf.Tensor` for the flattened tree.
global_state: hyper-parameters including noise multiplier, the branching
factor of the tree and the maximum records per user.
Returns:
a `tf.RaggedTensor` for the tree.
"""
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
# row_splits=[0, 4, 4, 7, 8, 8]))
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
row_splits = [0] + [
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, global_state

View file

@ -13,16 +13,15 @@
# limitations under the License. # limitations under the License.
"""Tests for `tree_aggregation_query`.""" """Tests for `tree_aggregation_query`."""
from absl.testing import parameterized import math
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import test_utils from tensorflow_privacy.privacy.dp_query import test_utils
from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
STRUCT_RECORD = [ STRUCT_RECORD = [
tf.constant([[2.0, 0.0], [0.0, 1.0]]), tf.constant([[2.0, 0.0], [0.0, 1.0]]),
tf.constant([-1.0, 0.0]) tf.constant([-1.0, 0.0])
@ -55,6 +54,7 @@ def _get_noise_fn(specs, stddev=NOISE_STD, seed=1):
def _get_no_noise_fn(specs): def _get_no_noise_fn(specs):
shape = tf.nest.map_structure(lambda spec: spec.shape, specs) shape = tf.nest.map_structure(lambda spec: spec.shape, specs)
def no_noise_fn(): def no_noise_fn():
return tf.nest.map_structure(tf.zeros, shape) return tf.nest.map_structure(tf.zeros, shape)
@ -73,6 +73,7 @@ def _get_l2_clip_fn():
def _get_l_infty_clip_fn(): def _get_l_infty_clip_fn():
def l_infty_clip_fn(record_as_list, clip_value): def l_infty_clip_fn(record_as_list, clip_value):
def clip(record): def clip(record):
return tf.clip_by_value( return tf.clip_by_value(
record, clip_value_min=-clip_value, clip_value_max=clip_value) record, clip_value_min=-clip_value, clip_value_max=clip_value)
@ -395,5 +396,283 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(query._tree_aggregator, tree_class) self.assertIsInstance(query._tree_aggregator, tree_class)
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
leaf_nodes_size=[1, 2, 3, 4, 5],
arity=[2, 3],
dtype=[tf.int32, tf.float32],
)
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
"""Test whether `_build_tree_from_leaf` will output the correct tree."""
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
tree = tree_aggregation_query._build_tree_from_leaf(leaf_nodes, arity)
self.assertEqual(depth, tree.shape[0])
for layer in range(depth):
reverse_depth = tree.shape[0] - layer - 1
span_size = arity**reverse_depth
for idx in range(arity**layer):
left = idx * span_size
right = (idx + 1) * span_size
expected_value = sum(leaf_nodes[left:right])
self.assertEqual(tree[layer][idx], expected_value)
class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_initial_global_state_type(self):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
self.assertIsInstance(
global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState)
def test_derive_sample_params(self):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
self.assertAllClose(params, 10.)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.],
dtype=tf.float32)),
)
def test_preprocess_record(self, arity, record):
query = tree_aggregation_query.CentralTreeSumQuery(
stddev=NOISE_STD, arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, record)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
('binary_test_float', 2, tf.constant(
[10., 10., 0., 0.],
dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
)
def test_preprocess_record_clipped(self, arity, record,
expected_clipped_value):
query = tree_aggregation_query.CentralTreeSumQuery(
stddev=NOISE_STD, arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_clipped_value)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result(self, arity, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.named_parameters(
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state_list = []
for _ in range(1000):
sample_state, _ = query.get_noised_result(preprocessed_record,
global_state)
sample_state_list.append(sample_state.flat_values.numpy())
expectation = np.mean(sample_state_list, axis=0)
variance = np.std(sample_state_list, axis=0)
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
self.assertAllClose(
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_initial_global_state_type(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
self.assertIsInstance(
global_state,
tree_aggregation_query.DistributedTreeSumQuery.GlobalState)
def test_derive_sample_params(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
stddev, arity, l1_bound = query.derive_sample_params(
global_state)
self.assertAllClose(stddev, NOISE_STD)
self.assertAllClose(arity, 2)
self.assertAllClose(l1_bound, 10)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
])),
)
def test_preprocess_record(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record_list = []
for _ in range(1000):
preprocessed_record = query.preprocess_record(params, record)
preprocessed_record_list.append(preprocessed_record.numpy())
expectation = np.mean(preprocessed_record_list, axis=0)
variance = np.std(preprocessed_record_list, axis=0)
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
self.assertAllClose(
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant(
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant(
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
)
def test_preprocess_record_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()

View file

@ -365,39 +365,5 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
self.assertAllEqual(gstate, gstate2) self.assertAllEqual(gstate, gstate2)
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
leaf_nodes_size=[1, 2, 3, 4, 5],
arity=[2, 3],
dtype=[tf.int32, tf.float32],
)
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
"""Test whether `build_tree_from_leaf` will output the correct tree."""
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
tree = tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
self.assertEqual(depth, tree.shape[0])
for layer in range(depth):
reverse_depth = tree.shape[0] - layer - 1
span_size = arity**reverse_depth
for idx in range(arity**layer):
left = idx * span_size
right = (idx + 1) * span_size
expected_value = sum(leaf_nodes[left:right])
self.assertEqual(tree[layer][idx], expected_value)
@parameterized.named_parameters(('negative_arity', [1], -1),
('empty_hist', [], 2))
def test_value_error_raises(self, leaf_nodes, arity):
"""Test whether `build_tree_from_leaf` will raise the correct error when the input is illegal."""
with self.assertRaises(ValueError):
tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()