parent ef83391ce6
commit b9e4cf1a20
2 changed files with 0 additions and 459 deletions
@@ -22,7 +22,6 @@ the leaf nodes of the tree arrive one by one as the time proceeds.
"""

import distutils
import math
from typing import Optional

import attr
import tensorflow as tf

@@ -732,224 +731,3 @@ def _get_add_noise(stddev, seed: int = None):
    return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)

  return add_noise

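# For reference, a minimal, self-contained sketch of the kind of closure
# `_get_add_noise` returns: elementwise Gaussian noise added to a tensor. The
# helper name and the stateless-seed choice below are illustrative assumptions,
# not the module's implementation.
def _example_add_noise(v, stddev=0.1, seed=0):
  noise = tf.random.stateless_normal(
      tf.shape(input=v), seed=[seed, 0], stddev=stddev)
  return v + tf.cast(noise, dtype=v.dtype)
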
class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
  """Implements dp_query for a differentially private tree aggregation protocol.

  Implements a central variant of the tree aggregation protocol from the paper
  "Is interaction necessary for distributed private learning?" by Adam Smith,
  Abhradeep Thakurta, and Jalaj Upadhyay, replacing their local randomizer with
  the Gaussian mechanism. The first step is to clip each client's local update
  (i.e. a 1-D array containing the leaf nodes of the tree) by L1 norm so that
  it does not exceed a prespecified upper bound. The second step is to
  construct the tree on the clipped update. The third step is to add
  independent Gaussian noise to each node in the tree. The returned tree
  supports efficient and accurate range queries with differential privacy.
  """

  @attr.s(frozen=True)
  class GlobalState(object):
    """Class defining global state for `CentralTreeSumQuery`.

    Attributes:
      l1_bound: An upper bound on the L1 norm of the input record. This is
        needed to bound the sensitivity and deploy differential privacy.
    """
    l1_bound = attr.ib()

  def __init__(self,
               stddev: float,
               arity: int = 2,
               l1_bound: int = 10,
               seed: Optional[int] = None):
    """Initializes the `CentralTreeSumQuery`.

    Args:
      stddev: The stddev of the noise added to each internal node of the
        constructed tree.
      arity: The branching factor of the tree.
      l1_bound: An upper bound on the L1 norm of the input record. This is
        needed to bound the sensitivity and deploy differential privacy.
      seed: Random seed to generate Gaussian noise. Defaults to `None`. Only
        for test purposes.
    """
    self._stddev = stddev
    self._arity = arity
    self._l1_bound = l1_bound
    self._seed = seed

  def initial_global_state(self):
    """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
    return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)

  def derive_sample_params(self, global_state):
    """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
    return global_state.l1_bound

  def preprocess_record(self, params, record):
    """Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
    casted_record = tf.cast(record, tf.float32)
    l1_norm = tf.norm(casted_record, ord=1)

    l1_bound = tf.cast(params, tf.float32)

    preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
                                                    l1_bound,
                                                    use_norm=l1_norm)

    return preprocessed_record[0]

  def get_noised_result(self, sample_state, global_state):
    """Implements `tensorflow_privacy.DPQuery.get_noised_result`.

    Args:
      sample_state: A frequency histogram.
      global_state: Hyper-parameters of the query.

    Returns:
      A `tf.RaggedTensor` representing the tree built on top of `sample_state`.
      The jth node on the ith layer of the tree can be accessed as tree[i][j],
      where tree is the returned value.
    """
    add_noise = _get_add_noise(self._stddev, self._seed)
    tree = _build_tree_from_leaf(sample_state, self._arity)
    return tf.map_fn(add_noise, tree), global_state

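# A minimal usage sketch of the query lifecycle above. The record values and
# the helper name are illustrative assumptions; with `stddev=0.` the returned
# tree is exact.
def _example_central_tree_sum_usage():
  query = CentralTreeSumQuery(stddev=0., arity=2)
  global_state = query.initial_global_state()
  params = query.derive_sample_params(global_state)  # the L1 bound, 10 by default
  record = tf.constant([1., 0., 0., 0.])  # one client's leaf histogram
  preprocessed = query.preprocess_record(params, record)  # L1-clipped leaves
  tree, _ = query.get_noised_result(preprocessed, global_state)
  # `tree` is a tf.RaggedTensor: [[1.], [1., 0.], [1., 0., 0., 0.]]
  return tree
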
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
  """Implements dp_query for a differentially private tree aggregation protocol.

  The difference from `CentralTreeSumQuery` is that the tree construction and
  Gaussian noise addition happen in `preprocess_record`. The difference only
  takes effect when used together with
  `tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class
  should be treated as equivalent to `CentralTreeSumQuery`.

  Implements a distributed variant of the tree aggregation protocol from
  "Is interaction necessary for distributed private learning?" by Adam Smith,
  Abhradeep Thakurta, and Jalaj Upadhyay, replacing their local randomizer with
  the Gaussian mechanism. The first step is to check the L1 norm of a client's
  local update (i.e. a 1-D array containing the leaf nodes of the tree) to make
  sure it does not exceed a prespecified upper bound. The second step is to
  construct the tree. The third step is to add independent Gaussian noise to
  each node in the tree. The returned tree supports efficient and accurate
  range queries with differential privacy.
  """

  @attr.s(frozen=True)
  class GlobalState(object):
    """Class defining global state for `DistributedTreeSumQuery`.

    Attributes:
      stddev: The stddev of the noise added to each internal node in the
        constructed tree.
      arity: The branching factor of the tree (i.e. the number of children each
        internal node has).
      l1_bound: An upper bound on the L1 norm of the input record. This is
        needed to bound the sensitivity and deploy differential privacy.
    """
    stddev = attr.ib()
    arity = attr.ib()
    l1_bound = attr.ib()

  def __init__(self,
               stddev: float,
               arity: int = 2,
               l1_bound: int = 10,
               seed: Optional[int] = None):
    """Initializes the `DistributedTreeSumQuery`.

    Args:
      stddev: The stddev of the noise added to each node in the tree.
      arity: The branching factor of the tree.
      l1_bound: An upper bound on the L1 norm of the input record. This is
        needed to bound the sensitivity and deploy differential privacy.
      seed: Random seed to generate Gaussian noise. Defaults to `None`. Only
        for test purposes.
    """
    self._stddev = stddev
    self._arity = arity
    self._l1_bound = l1_bound
    self._seed = seed

  def initial_global_state(self):
    """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
    return DistributedTreeSumQuery.GlobalState(
        stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)

  def derive_sample_params(self, global_state):
    """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
    return (global_state.stddev, global_state.arity, global_state.l1_bound)

  def preprocess_record(self, params, record):
    """Implements `tensorflow_privacy.DPQuery.preprocess_record`.

    This method clips the input record by L1 norm, constructs a tree on top of
    it, and adds Gaussian noise to each node of the tree for differential
    privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function
    flattens the `tf.RaggedTensor` before outputting it. This is useful when
    used inside `tff.aggregators.DifferentiallyPrivateFactory`, which does not
    accept ragged output tensors.

    Args:
      params: Hyper-parameters for preprocessing the record, a tuple of
        (stddev, arity, l1_bound).
      record: Leaf nodes for the tree.

    Returns:
      A `tf.Tensor` representing the flattened version of the tree.
    """
    _, arity, l1_bound_ = params
    l1_bound = tf.cast(l1_bound_, tf.float32)

    casted_record = tf.cast(record, tf.float32)
    l1_norm = tf.norm(casted_record, ord=1)

    preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
                                                    l1_bound,
                                                    use_norm=l1_norm)
    preprocessed_record = preprocessed_record[0]

    add_noise = _get_add_noise(self._stddev, self._seed)
    tree = _build_tree_from_leaf(preprocessed_record, arity)
    noisy_tree = tf.map_fn(add_noise, tree)

    # The following code reshapes the output vector so that its shape can be
    # statically inferred. This is useful when used with
    # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
    # the output shape of this function statically and explicitly.
    flat_noisy_tree = noisy_tree.flat_values
    flat_tree_shape = [
        (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
         1) // (self._arity - 1)
    ]
    return tf.reshape(flat_noisy_tree, flat_tree_shape)

  def get_noised_result(self, sample_state, global_state):
    """Implements `tensorflow_privacy.DPQuery.get_noised_result`.

    This function reconstructs the `tf.RaggedTensor` from the flattened tree
    output by `preprocess_record`.

    Args:
      sample_state: A `tf.Tensor` for the flattened tree.
      global_state: Hyper-parameters of the query: the noise stddev, the
        branching factor of the tree, and the L1 bound.

    Returns:
      A `tf.RaggedTensor` for the tree.
    """
    # The leading [0] is needed because `tf.RaggedTensor.from_row_splits`
    # expects the first row split to be 0, e.g.
    # print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
    #                                       row_splits=[0, 4, 4, 7, 8, 8]))
    # <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
    # This part is not written in tensorflow and will be executed on the server
    # side instead of the client side if used with
    # tff.aggregators.DifferentiallyPrivateFactory for federated learning.
    row_splits = [0] + [
        (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
            math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
    ]
    tree = tf.RaggedTensor.from_row_splits(
        values=sample_state, row_splits=row_splits)
    return tree, global_state

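# A worked example of the shape bookkeeping used by `preprocess_record` and
# `get_noised_result` above, for arity 2 and four leaves. The helper name is
# illustrative; only the arithmetic mirrors the code above.
def _example_flat_tree_shape(num_leaves=4, arity=2):
  depth = math.ceil(math.log(num_leaves, arity)) + 1  # 3 layers: 1, 2, 4 nodes
  flat_size = (arity**depth - 1) // (arity - 1)  # 7 nodes in the flattened tree
  row_splits = [0] + [
      (arity**(x + 1) - 1) // (arity - 1)
      for x in range(math.floor(math.log(flat_size, arity)) + 1)
  ]  # [0, 1, 3, 7]: layer boundaries for tf.RaggedTensor.from_row_splits
  return flat_size, row_splits
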
@@ -630,242 +630,5 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
        sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)


class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):

  def test_initial_global_state_type(self):
    query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
    global_state = query.initial_global_state()
    self.assertIsInstance(
        global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState)

  def test_derive_sample_params(self):
    query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    self.assertAllClose(params, 10.)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
      ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)),
      ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
      ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.],
                                            dtype=tf.float32)),
  )
  def test_preprocess_record(self, arity, record):
    query = tree_aggregation_query.CentralTreeSumQuery(
        stddev=NOISE_STD, arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)

    self.assertAllClose(preprocessed_record, record)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.constant([5, 5, 0, 0], dtype=tf.int32)),
      ('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
                                           dtype=tf.float32),
       tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
      ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.constant([5, 5, 0, 0], dtype=tf.int32)),
      ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
                                            dtype=tf.float32),
       tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
  )
  def test_preprocess_record_clipped(self, arity, record,
                                     expected_clipped_value):
    query = tree_aggregation_query.CentralTreeSumQuery(
        stddev=NOISE_STD, arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)
    self.assertAllClose(preprocessed_record, expected_clipped_value)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
      ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
       tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
      ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[1.], [1., 0., 0.],
                           [1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
      ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
       tf.ragged.constant([[1.], [1., 0., 0.],
                           [1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
  )
  def test_get_noised_result(self, arity, record, expected_tree):
    query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)
    sample_state, global_state = query.get_noised_result(
        preprocessed_record, global_state)

    self.assertAllClose(sample_state, expected_tree)

  @parameterized.named_parameters(
      ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
      ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
  )
  def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
    query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)

    sample_state, _ = query.get_noised_result(preprocessed_record, global_state)

    self.assertAllClose(
        sample_state.flat_values, expected_tree, atol=3 * stddev)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
      ('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
                                           dtype=tf.float32),
       tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
      ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[10.], [10., 0., 0.],
                           [5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
      ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
                                            dtype=tf.float32),
       tf.ragged.constant([[10.], [10., 0., 0.],
                           [5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
  )
  def test_get_noised_result_clipped(self, arity, record, expected_tree):
    query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)
    sample_state, global_state = query.get_noised_result(
        preprocessed_record, global_state)

    self.assertAllClose(sample_state, expected_tree)


class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):

  def test_initial_global_state_type(self):
    query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
    global_state = query.initial_global_state()
    self.assertIsInstance(
        global_state,
        tree_aggregation_query.DistributedTreeSumQuery.GlobalState)

  def test_derive_sample_params(self):
    query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
    global_state = query.initial_global_state()
    stddev, arity, l1_bound = query.derive_sample_params(global_state)
    self.assertAllClose(stddev, NOISE_STD)
    self.assertAllClose(arity, 2)
    self.assertAllClose(l1_bound, 10)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
       tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
      ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
       tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
      ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
       tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
                          ])),
      ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
       tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
                          ])),
  )
  def test_preprocess_record(self, arity, record, expected_tree):
    query = tree_aggregation_query.DistributedTreeSumQuery(
        stddev=0., arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)
    self.assertAllClose(preprocessed_record, expected_tree)

  @parameterized.named_parameters(
      ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
      ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
  )
  def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
    query = tree_aggregation_query.DistributedTreeSumQuery(
        stddev=stddev, seed=0)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)

    preprocessed_record = query.preprocess_record(params, record)

    self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
      ('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
                                           dtype=tf.float32),
       tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
      ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.ragged.constant(
           [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
      ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
                                            dtype=tf.float32),
       tf.ragged.constant(
           [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
  )
  def test_preprocess_record_clipped(self, arity, record, expected_tree):
    query = tree_aggregation_query.DistributedTreeSumQuery(
        stddev=0., arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)
    self.assertAllClose(preprocessed_record, expected_tree)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
      ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
       tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
      ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[1.], [1., 0., 0.],
                           [1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
      ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
       tf.ragged.constant([[1.], [1., 0., 0.],
                           [1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
  )
  def test_get_noised_result(self, arity, record, expected_tree):
    query = tree_aggregation_query.DistributedTreeSumQuery(
        stddev=0., arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)
    sample_state, global_state = query.get_noised_result(
        preprocessed_record, global_state)

    self.assertAllClose(sample_state, expected_tree)

  @parameterized.named_parameters(
      ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
      ('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
                                           dtype=tf.float32),
       tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
      ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
       tf.ragged.constant([[10.], [10., 0., 0.],
                           [5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
      ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
                                            dtype=tf.float32),
       tf.ragged.constant([[10.], [10., 0., 0.],
                           [5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
  )
  def test_get_noised_result_clipped(self, arity, record, expected_tree):
    query = tree_aggregation_query.DistributedTreeSumQuery(
        stddev=0., arity=arity)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    preprocessed_record = query.preprocess_record(params, record)
    sample_state, global_state = query.get_noised_result(
        preprocessed_record, global_state)

    self.assertAllClose(sample_state, expected_tree)


if __name__ == '__main__':
  tf.test.main()