Automated rollback of commit 4d335d1b69

PiperOrigin-RevId: 387254617
This commit is contained in:
Keith Rush 2021-07-27 20:03:58 -07:00 committed by A. Unique TensorFlower
parent 4d335d1b69
commit eef5810d94
2 changed files with 393 additions and 207 deletions

View file

@ -15,18 +15,21 @@
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
online observation queries relying on `tree_aggregation`. 'Online' means that
the leaf nodes of the tree arrive one by one as the time proceeds.
the leaf nodes of the tree arrive one by one as the time proceeds. The leaves
are vector records as defined in `dp_query.DPQuery`.
`TreeRangeSumQuery` is a `DPQuery`s for offline tree aggregation protocol.
'Offline' means all the leaf nodes are ready before the protocol starts.
`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for
central/distributed offline tree aggregation protocol. 'Offline' means all the
leaf nodes are ready before the protocol starts. Each record, different from
what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
"""
import distutils
import math
from typing import Optional
import attr
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
@ -439,84 +442,217 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
return tree
class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for accurate range queries using tree aggregation.
def _get_add_noise(stddev, seed: int = None):
"""Utility function to decide which `add_noise` to use according to tf version."""
if distutils.version.LooseVersion(
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
Implements a variant of the tree aggregation protocol from. "Is interaction
necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta,
Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to
the tree for differential privacy. Any range query can be decomposed into the
sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram.
Improves efficiency and reduces noise scale.
# The seed should be only used for testing purpose.
if seed is not None:
tf.random.set_seed(seed)
def add_noise(v):
return v + tf.random.normal(
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
else:
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
return add_noise
class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for differentially private tree aggregation protocol.
Implements a central variant of the tree aggregation protocol from the paper
"'Is interaction necessary for distributed private learning?.' Adam Smith,
Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with
gaussian mechanism. The first step is to clip the clients' local updates (i.e.
a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it
does not exceed a prespecified upper bound. The second step is to construct
the tree on the clipped update. The third step is to add independent gaussian
noise to each node in the tree. The returned tree can support efficient and
accurate range queries with differential privacy.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for TreeRangeSumQuery.
"""Class defining global state for `CentralTreeSumQuery`.
Attributes:
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
inner_query_state: The global state of the inner query.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
arity = attr.ib()
inner_query_state = attr.ib()
l1_bound = attr.ib()
def __init__(self,
inner_query: dp_query.SumAggregationDPQuery,
arity: int = 2):
"""Initializes the `TreeRangeSumQuery`.
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `CentralTreeSumQuery`.
Args:
inner_query: The inner `DPQuery` that adds noise to the tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
stddev: The stddev of the noise added to each internal node of the
constructed tree.
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._inner_query = inner_query
self._stddev = stddev
self._arity = arity
if self._arity < 1:
raise ValueError(f'Invalid arity={arity} smaller than 2.')
self._l1_bound = l1_bound
self._seed = seed
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return TreeRangeSumQuery.GlobalState(
arity=self._arity,
inner_query_state=self._inner_query.initial_global_state())
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.arity,
self._inner_query.derive_sample_params(
global_state.inner_query_state))
return global_state.l1_bound
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
casted_record = tf.cast(record, tf.float32)
l1_norm = tf.norm(casted_record, ord=1)
l1_bound = tf.cast(params, tf.float32)
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
l1_bound,
use_norm=l1_norm)
return preprocessed_record[0]
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
Args:
sample_state: a frequency histogram.
global_state: hyper-parameters of the query.
Returns:
a `tf.RaggedTensor` representing the tree built on top of `sample_state`.
The jth node on the ith layer of the tree can be accessed by tree[i][j]
where tree is the returned value.
"""
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(sample_state, self._arity)
return tf.map_fn(add_noise, tree), global_state
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for differentially private tree aggregation protocol.
The difference from `CentralTreeSumQuery` is that the tree construction and
gaussian noise addition happen in `preprocess_records`. The difference only
takes effect when used together with
`tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class
should be treated as equal with `CentralTreeSumQuery`.
Implements a distributed version of the tree aggregation protocol from. "Is
interaction necessary for distributed private learning?." by replacing their
local randomizer with gaussian mechanism. The first step is to check the L1
norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes
of the tree) to make sure it does not exceed a prespecified upper bound. The
second step is to construct the tree. The third step is to add independent
gaussian noise to each node in the tree. The returned tree can support
efficient and accurate range queries with differential privacy.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for DistributedTreeSumQuery.
Attributes:
stddev: The stddev of the noise added to each internal node in the
constructed tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
stddev = attr.ib()
arity = attr.ib()
l1_bound = attr.ib()
def __init__(self,
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `DistributedTreeSumQuery`.
Args:
stddev: The stddev of the noise added to each node in the tree.
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
self._seed = seed
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return DistributedTreeSumQuery.GlobalState(
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.stddev, global_state.arity, global_state.l1_bound)
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
This method builds the tree, flattens it and applies
`inner_query.preprocess_record` to the flattened tree.
This method clips the input record by L1 norm, constructs a tree on top of
it, and adds gaussian noise to each node of the tree for differential
privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function
flattens the `tf.RaggedTensor` before outputting it. This is useful when
used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does
not accept ragged output tensor.
Args:
params: Hyper-parameters for preprocessing record.
record: A histogram representing the leaf nodes of the tree.
params: hyper-parameters for preprocessing record, (stddev, aritry,
l1_bound)
record: leaf nodes for the tree.
Returns:
A `tf.Tensor` representing the flattened version of the preprocessed tree.
`tf.Tensor` representing the flattened version of the tree.
"""
arity, inner_query_params = params
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
preprocessed_record = self._inner_query.preprocess_record(
inner_query_params, preprocessed_record)
_, arity, l1_bound_ = params
l1_bound = tf.cast(l1_bound_, tf.float32)
casted_record = tf.cast(record, tf.float32)
l1_norm = tf.norm(casted_record, ord=1)
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
l1_bound,
use_norm=l1_norm)
preprocessed_record = preprocessed_record[0]
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(preprocessed_record, arity)
noisy_tree = tf.map_fn(add_noise, tree)
# The following codes reshape the output vector so the output shape of can
# be statically inferred. This is useful when used with
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
# the output shape of this function statically and explicitly.
preprocessed_record_shape = [
flat_noisy_tree = noisy_tree.flat_values
flat_tree_shape = [
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
1) // (self._arity - 1)
]
return tf.reshape(preprocessed_record, preprocessed_record_shape)
return tf.reshape(flat_noisy_tree, flat_tree_shape)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
@ -525,11 +661,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
output by `preprocess_records.`
Args:
sample_state: A `tf.Tensor` for the flattened tree.
global_state: The global state of the protocol.
sample_state: `tf.Tensor` for the flattened tree.
global_state: hyper-parameters including noise multiplier, the branching
factor of the tree and the maximum records per user.
Returns:
A `tf.RaggedTensor` representing the tree.
a `tf.RaggedTensor` for the tree.
"""
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
@ -545,60 +682,3 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, global_state
@classmethod
def build_central_gaussian_query(cls,
l2_norm_clip: float,
stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_clip: Each record should be clipped so that it has L2 norm at most
`l2_norm_clip`.
stddev: Stddev of the central Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_clip <= 0:
raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.')
if stddev < 0:
raise ValueError(f'`stddev` must be non-negative, got {stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev)
return cls(arity=arity, inner_query=inner_query)
@classmethod
def build_distributed_discrete_gaussian_query(cls,
l2_norm_bound: float,
local_stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_bound: Each record should be clipped so that it has L2 norm at
most `l2_norm_bound`.
local_stddev: Scale/stddev of the local discrete Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_bound <= 0:
raise ValueError(
f'`l2_clip_bound` must be positive, got {l2_norm_bound}.')
if local_stddev < 0:
raise ValueError(
f'`local_stddev` must be non-negative, got {local_stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery(
l2_norm_bound, local_stddev)
return cls(arity=arity, inner_query=inner_query)

View file

@ -423,115 +423,72 @@ class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(tree[layer][idx], expected_value)
class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
inner_query=['central', 'distributed'],
params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)],
)
def test_raises_error(self, inner_query, params):
clip_norm, stddev, arity = params
with self.assertRaises(ValueError):
if inner_query == 'central':
tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
def test_initial_global_state_type(self):
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0])
def test_initial_global_state_type(self, inner_query, clip_norm, stddev):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev)
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
self.assertIsInstance(global_state,
tree_aggregation_query.TreeRangeSumQuery.GlobalState)
self.assertIsInstance(
global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0],
arity=[2, 3, 4])
def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
global_state = query.initial_global_state()
derived_arity, inner_query_state = query.derive_sample_params(global_state)
self.assertAllClose(derived_arity, arity)
if inner_query == 'central':
self.assertAllClose(inner_query_state, clip_norm)
elif inner_query == 'distributed':
self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm)
self.assertAllClose(inner_query_state.local_stddev, stddev)
@parameterized.product(
(dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]),
dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])),
inner_query=['central', 'distributed'],
)
def test_preprocess_record(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
def test_derive_sample_params(self):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
self.assertAllClose(params, 10.)
@parameterized.named_parameters(
('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.],
dtype=tf.float32)),
)
def test_distributed_preprocess_record_with_noise(self, local_stddev, record,
expected_tree):
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., local_stddev)
def test_preprocess_record(self, arity, record):
query = tree_aggregation_query.CentralTreeSumQuery(
stddev=NOISE_STD, arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(
preprocessed_record, expected_tree, atol=10 * local_stddev)
self.assertAllClose(preprocessed_record, record)
@parameterized.product(
(dict(
arity=2,
expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])),
dict(
arity=3,
expected_tree=tf.ragged.constant([[1], [1, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0]]))),
inner_query=['central', 'distributed'],
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
('binary_test_float', 2, tf.constant(
[10., 10., 0., 0.],
dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
)
def test_get_noised_result(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
def test_preprocess_record_clipped(self, arity, record,
expected_clipped_value):
query = tree_aggregation_query.CentralTreeSumQuery(
stddev=NOISE_STD, arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_clipped_value)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result(self, arity, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
@ -540,18 +497,167 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(sample_state, expected_tree)
@parameterized.product(stddev=[0.1, 1.0, 10.0])
def test_central_get_noised_result_with_noise(self, stddev):
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
10., stddev)
@parameterized.named_parameters(
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
preprocessed_record = query.preprocess_record(params, record)
sample_state, _ = query.get_noised_result(preprocessed_record, global_state)
self.assertAllClose(
sample_state.flat_values, expected_tree, atol=3 * stddev)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
self.assertAllClose(sample_state, expected_tree)
class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_initial_global_state_type(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
self.assertIsInstance(
global_state,
tree_aggregation_query.DistributedTreeSumQuery.GlobalState)
def test_derive_sample_params(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
stddev, arity, l1_bound = query.derive_sample_params(global_state)
self.assertAllClose(stddev, NOISE_STD)
self.assertAllClose(arity, 2)
self.assertAllClose(l1_bound, 10)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
])),
)
def test_preprocess_record(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant(
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant(
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
)
def test_preprocess_record_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
if __name__ == '__main__':