forked from 626_privacy/tensorflow_privacy
Move TreeRangeSumQuery
to its own module. This is the first step, will remove the function in the old module after a TFP release.
PiperOrigin-RevId: 392776774
This commit is contained in:
parent
477b5b2899
commit
853b18929d
4 changed files with 472 additions and 5 deletions
|
@ -56,6 +56,7 @@ else:
|
|||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery
|
||||
|
||||
# Estimators
|
||||
from tensorflow_privacy.privacy.estimators.dnn import DNNClassifier
|
||||
|
|
|
@ -15,10 +15,9 @@
|
|||
|
||||
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
|
||||
online observation queries relying on `tree_aggregation`. 'Online' means that
|
||||
the leaf nodes of the tree arrive one by one as the time proceeds.
|
||||
|
||||
`TreeRangeSumQuery` is a `DPQuery`s for offline tree aggregation protocol.
|
||||
'Offline' means all the leaf nodes are ready before the protocol starts.
|
||||
the leaf nodes of the tree arrive one by one as the time proceeds. The core
|
||||
logic of tree aggregation is implemented in `tree_aggregation.TreeAggregator`
|
||||
and `tree_aggregation.EfficientTreeAggregator`.
|
||||
"""
|
||||
import distutils
|
||||
import math
|
||||
|
@ -31,7 +30,7 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
|
|||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
|
||||
|
||||
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be
|
||||
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
|
||||
# in the same module.
|
||||
|
||||
|
||||
|
@ -477,6 +476,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
restart_indicator=restart_indicator)
|
||||
|
||||
|
||||
# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next
|
||||
# TFP release
|
||||
|
||||
|
||||
@tf.function
|
||||
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||
"""A function constructs a complete tree given all the leaf nodes.
|
||||
|
|
281
tensorflow_privacy/privacy/dp_query/tree_range_query.py
Normal file
281
tensorflow_privacy/privacy/dp_query/tree_range_query.py
Normal file
|
@ -0,0 +1,281 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""`DPQuery`s for offline differentially private tree aggregation protocols.
|
||||
|
||||
'Offline' means all the leaf nodes are ready before the protocol starts.
|
||||
"""
|
||||
import distutils
|
||||
import math
|
||||
|
||||
import attr
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
|
||||
|
||||
@tf.function
|
||||
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||
"""A function constructs a complete tree given all the leaf nodes.
|
||||
|
||||
The function takes a 1-D array representing the leaf nodes of a tree and the
|
||||
tree's arity, and constructs a complete tree by recursively summing the
|
||||
adjacent children to get the parent until reaching the root node. Because we
|
||||
assume a complete tree, if the number of leaf nodes does not divide arity, the
|
||||
leaf nodes will be padded with zeros.
|
||||
|
||||
Args:
|
||||
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
|
||||
arity: A `int` for the branching factor of the tree, i.e. the number of
|
||||
children for each internal node.
|
||||
|
||||
Returns:
|
||||
`tf.RaggedTensor` representing the tree. For example, if
|
||||
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
|
||||
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
|
||||
`tree[layer][index]` can be used to access the node indexed by (layer,
|
||||
index) in the tree,
|
||||
"""
|
||||
|
||||
def pad_zero(leaf_nodes, size):
|
||||
paddings = [[0, size - len(leaf_nodes)]]
|
||||
return tf.pad(leaf_nodes, paddings)
|
||||
|
||||
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
|
||||
num_layers = tf.math.ceil(
|
||||
tf.math.log(leaf_nodes_size) /
|
||||
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
|
||||
leaf_nodes = pad_zero(
|
||||
leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1))
|
||||
|
||||
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
|
||||
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
|
||||
|
||||
# The following `tf.while_loop` constructs the tree from bottom up by
|
||||
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
|
||||
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
|
||||
# support auto-translation from python loop to tf loop when loop variables
|
||||
# contain a `RaggedTensor` whose shape changes across iterations.
|
||||
|
||||
idx = tf.identity(num_layers)
|
||||
loop_cond = lambda i, h: tf.less_equal(2.0, i)
|
||||
|
||||
def _loop_body(i, h):
|
||||
return [
|
||||
tf.add(i, -1.0),
|
||||
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
|
||||
]
|
||||
|
||||
_, tree = tf.while_loop(
|
||||
loop_cond,
|
||||
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
|
||||
shape_invariants=[
|
||||
idx.get_shape(),
|
||||
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
|
||||
])
|
||||
|
||||
return tree
|
||||
|
||||
|
||||
class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements dp_query for accurate range queries using tree aggregation.
|
||||
|
||||
Implements a variant of the tree aggregation protocol from. "Is interaction
|
||||
necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta,
|
||||
Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to
|
||||
the tree for differential privacy. Any range query can be decomposed into the
|
||||
sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram.
|
||||
Improves efficiency and reduces noise scale.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
class GlobalState(object):
|
||||
"""Class defining global state for TreeRangeSumQuery.
|
||||
|
||||
Attributes:
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has).
|
||||
inner_query_state: The global state of the inner query.
|
||||
"""
|
||||
arity = attr.ib()
|
||||
inner_query_state = attr.ib()
|
||||
|
||||
def __init__(self,
|
||||
inner_query: dp_query.SumAggregationDPQuery,
|
||||
arity: int = 2):
|
||||
"""Initializes the `TreeRangeSumQuery`.
|
||||
|
||||
Args:
|
||||
inner_query: The inner `DPQuery` that adds noise to the tree.
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has). Defaults to 2.
|
||||
"""
|
||||
self._inner_query = inner_query
|
||||
self._arity = arity
|
||||
|
||||
if self._arity < 1:
|
||||
raise ValueError(f'Invalid arity={arity} smaller than 2.')
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return TreeRangeSumQuery.GlobalState(
|
||||
arity=self._arity,
|
||||
inner_query_state=self._inner_query.initial_global_state())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return (global_state.arity,
|
||||
self._inner_query.derive_sample_params(
|
||||
global_state.inner_query_state))
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||
|
||||
This method builds the tree, flattens it and applies
|
||||
`inner_query.preprocess_record` to the flattened tree.
|
||||
|
||||
Args:
|
||||
params: Hyper-parameters for preprocessing record.
|
||||
record: A histogram representing the leaf nodes of the tree.
|
||||
|
||||
Returns:
|
||||
A `tf.Tensor` representing the flattened version of the preprocessed tree.
|
||||
"""
|
||||
arity, inner_query_params = params
|
||||
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
|
||||
# The following codes reshape the output vector so the output shape of can
|
||||
# be statically inferred. This is useful when used with
|
||||
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
||||
# the output shape of this function statically and explicitly.
|
||||
preprocessed_record_shape = [
|
||||
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
||||
1) // (self._arity - 1)
|
||||
]
|
||||
preprocessed_record = tf.reshape(preprocessed_record,
|
||||
preprocessed_record_shape)
|
||||
preprocessed_record = self._inner_query.preprocess_record(
|
||||
inner_query_params, preprocessed_record)
|
||||
|
||||
return preprocessed_record
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||
|
||||
This function re-constructs the `tf.RaggedTensor` from the flattened tree
|
||||
output by `preprocess_records.`
|
||||
|
||||
Args:
|
||||
sample_state: A `tf.Tensor` for the flattened tree.
|
||||
global_state: The global state of the protocol.
|
||||
|
||||
Returns:
|
||||
A `tf.RaggedTensor` representing the tree.
|
||||
"""
|
||||
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
|
||||
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
|
||||
# row_splits=[0, 4, 4, 7, 8, 8]))
|
||||
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
|
||||
# This part is not written in tensorflow and will be executed on the server
|
||||
# side instead of the client side if used with
|
||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
||||
sample_state, global_state.inner_query_state)
|
||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||
arity=global_state.arity, inner_query_state=inner_query_state)
|
||||
|
||||
row_splits = [0] + [
|
||||
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
||||
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
|
||||
]
|
||||
tree = tf.RaggedTensor.from_row_splits(
|
||||
values=sample_state, row_splits=row_splits)
|
||||
return tree, new_global_state
|
||||
|
||||
@classmethod
|
||||
def build_central_gaussian_query(cls,
|
||||
l2_norm_clip: float,
|
||||
stddev: float,
|
||||
arity: int = 2):
|
||||
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
|
||||
|
||||
Args:
|
||||
l2_norm_clip: Each record should be clipped so that it has L2 norm at most
|
||||
`l2_norm_clip`.
|
||||
stddev: Stddev of the central Gaussian noise.
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has). Defaults to 2.
|
||||
"""
|
||||
if l2_norm_clip <= 0:
|
||||
raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.')
|
||||
|
||||
if stddev < 0:
|
||||
raise ValueError(f'`stddev` must be non-negative, got {stddev}.')
|
||||
|
||||
if arity < 2:
|
||||
raise ValueError(f'`arity` must be at least 2, got {arity}.')
|
||||
|
||||
inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev)
|
||||
|
||||
return cls(arity=arity, inner_query=inner_query)
|
||||
|
||||
@classmethod
|
||||
def build_distributed_discrete_gaussian_query(cls,
|
||||
l2_norm_bound: float,
|
||||
local_stddev: float,
|
||||
arity: int = 2):
|
||||
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
|
||||
|
||||
Args:
|
||||
l2_norm_bound: Each record should be clipped so that it has L2 norm at
|
||||
most `l2_norm_bound`.
|
||||
local_stddev: Scale/stddev of the local discrete Gaussian noise.
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has). Defaults to 2.
|
||||
"""
|
||||
if l2_norm_bound <= 0:
|
||||
raise ValueError(
|
||||
f'`l2_clip_bound` must be positive, got {l2_norm_bound}.')
|
||||
|
||||
if local_stddev < 0:
|
||||
raise ValueError(
|
||||
f'`local_stddev` must be non-negative, got {local_stddev}.')
|
||||
|
||||
if arity < 2:
|
||||
raise ValueError(f'`arity` must be at least 2, got {arity}.')
|
||||
|
||||
inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery(
|
||||
l2_norm_bound, local_stddev)
|
||||
|
||||
return cls(arity=arity, inner_query=inner_query)
|
||||
|
||||
|
||||
def _get_add_noise(stddev, seed: int = None):
|
||||
"""Utility function to decide which `add_noise` to use according to tf version."""
|
||||
if distutils.version.LooseVersion(
|
||||
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
||||
|
||||
# The seed should be only used for testing purpose.
|
||||
if seed is not None:
|
||||
tf.random.set_seed(seed)
|
||||
|
||||
def add_noise(v):
|
||||
return v + tf.random.normal(
|
||||
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
|
||||
else:
|
||||
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
|
||||
|
||||
def add_noise(v):
|
||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||
|
||||
return add_noise
|
182
tensorflow_privacy/privacy/dp_query/tree_range_query_test.py
Normal file
182
tensorflow_privacy/privacy/dp_query/tree_range_query_test.py
Normal file
|
@ -0,0 +1,182 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for `tree_range_query`."""
|
||||
|
||||
import math
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import tree_range_query
|
||||
|
||||
|
||||
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
leaf_nodes_size=[1, 2, 3, 4, 5],
|
||||
arity=[2, 3],
|
||||
dtype=[tf.int32, tf.float32],
|
||||
)
|
||||
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
|
||||
"""Test whether `_build_tree_from_leaf` will output the correct tree."""
|
||||
|
||||
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
|
||||
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
|
||||
|
||||
tree = tree_range_query._build_tree_from_leaf(leaf_nodes, arity)
|
||||
|
||||
self.assertEqual(depth, tree.shape[0])
|
||||
|
||||
for layer in range(depth):
|
||||
reverse_depth = tree.shape[0] - layer - 1
|
||||
span_size = arity**reverse_depth
|
||||
for idx in range(arity**layer):
|
||||
left = idx * span_size
|
||||
right = (idx + 1) * span_size
|
||||
expected_value = sum(leaf_nodes[left:right])
|
||||
self.assertEqual(tree[layer][idx], expected_value)
|
||||
|
||||
|
||||
class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
inner_query=['central', 'distributed'],
|
||||
params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)],
|
||||
)
|
||||
def test_raises_error(self, inner_query, params):
|
||||
clip_norm, stddev, arity = params
|
||||
with self.assertRaises(ValueError):
|
||||
if inner_query == 'central':
|
||||
tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
elif inner_query == 'distributed':
|
||||
tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
|
||||
@parameterized.product(
|
||||
inner_query=['central', 'distributed'],
|
||||
clip_norm=[0.1, 1.0, 10.0],
|
||||
stddev=[0.1, 1.0, 10.0])
|
||||
def test_initial_global_state_type(self, inner_query, clip_norm, stddev):
|
||||
|
||||
if inner_query == 'central':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
clip_norm, stddev)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
clip_norm, stddev)
|
||||
global_state = query.initial_global_state()
|
||||
self.assertIsInstance(global_state,
|
||||
tree_range_query.TreeRangeSumQuery.GlobalState)
|
||||
|
||||
@parameterized.product(
|
||||
inner_query=['central', 'distributed'],
|
||||
clip_norm=[0.1, 1.0, 10.0],
|
||||
stddev=[0.1, 1.0, 10.0],
|
||||
arity=[2, 3, 4])
|
||||
def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity):
|
||||
if inner_query == 'central':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
global_state = query.initial_global_state()
|
||||
derived_arity, inner_query_state = query.derive_sample_params(global_state)
|
||||
self.assertAllClose(derived_arity, arity)
|
||||
if inner_query == 'central':
|
||||
self.assertAllClose(inner_query_state, clip_norm)
|
||||
elif inner_query == 'distributed':
|
||||
self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm)
|
||||
self.assertAllClose(inner_query_state.local_stddev, stddev)
|
||||
|
||||
@parameterized.product(
|
||||
(dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]),
|
||||
dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
inner_query=['central', 'distributed'],
|
||||
)
|
||||
def test_preprocess_record(self, inner_query, arity, expected_tree):
|
||||
if inner_query == 'central':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
|
||||
expected_tree = tf.cast(expected_tree, tf.float32)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
self.assertAllClose(preprocessed_record, expected_tree)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
|
||||
('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
|
||||
)
|
||||
def test_distributed_preprocess_record_with_noise(self, local_stddev, record,
|
||||
expected_tree):
|
||||
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
10., local_stddev)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
|
||||
self.assertAllClose(
|
||||
preprocessed_record, expected_tree, atol=10 * local_stddev)
|
||||
|
||||
@parameterized.product(
|
||||
(dict(
|
||||
arity=2,
|
||||
expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])),
|
||||
dict(
|
||||
arity=3,
|
||||
expected_tree=tf.ragged.constant([[1], [1, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0, 0, 0, 0]]))),
|
||||
inner_query=['central', 'distributed'],
|
||||
)
|
||||
def test_get_noised_result(self, inner_query, arity, expected_tree):
|
||||
if inner_query == 'central':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
|
||||
expected_tree = tf.cast(expected_tree, tf.float32)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(sample_state, expected_tree)
|
||||
|
||||
@parameterized.product(stddev=[0.1, 1.0, 10.0])
|
||||
def test_central_get_noised_result_with_noise(self, stddev):
|
||||
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
10., stddev)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(
|
||||
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue