forked from 626_privacy/tensorflow_privacy
parent
4d335d1b69
commit
eef5810d94
2 changed files with 393 additions and 207 deletions
|
@ -15,18 +15,21 @@
|
||||||
|
|
||||||
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
|
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
|
||||||
online observation queries relying on `tree_aggregation`. 'Online' means that
|
online observation queries relying on `tree_aggregation`. 'Online' means that
|
||||||
the leaf nodes of the tree arrive one by one as the time proceeds.
|
the leaf nodes of the tree arrive one by one as the time proceeds. The leaves
|
||||||
|
are vector records as defined in `dp_query.DPQuery`.
|
||||||
|
|
||||||
`TreeRangeSumQuery` is a `DPQuery`s for offline tree aggregation protocol.
|
`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for
|
||||||
'Offline' means all the leaf nodes are ready before the protocol starts.
|
central/distributed offline tree aggregation protocol. 'Offline' means all the
|
||||||
|
leaf nodes are ready before the protocol starts. Each record, different from
|
||||||
|
what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
|
||||||
"""
|
"""
|
||||||
|
import distutils
|
||||||
import math
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||||
|
|
||||||
|
|
||||||
|
@ -439,84 +442,217 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
|
|
||||||
class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
def _get_add_noise(stddev, seed: int = None):
|
||||||
"""Implements dp_query for accurate range queries using tree aggregation.
|
"""Utility function to decide which `add_noise` to use according to tf version."""
|
||||||
|
if distutils.version.LooseVersion(
|
||||||
|
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
||||||
|
|
||||||
Implements a variant of the tree aggregation protocol from. "Is interaction
|
# The seed should be only used for testing purpose.
|
||||||
necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta,
|
if seed is not None:
|
||||||
Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to
|
tf.random.set_seed(seed)
|
||||||
the tree for differential privacy. Any range query can be decomposed into the
|
|
||||||
sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram.
|
def add_noise(v):
|
||||||
Improves efficiency and reduces noise scale.
|
return v + tf.random.normal(
|
||||||
|
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
|
||||||
|
else:
|
||||||
|
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
|
||||||
|
|
||||||
|
def add_noise(v):
|
||||||
|
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||||
|
|
||||||
|
return add_noise
|
||||||
|
|
||||||
|
|
||||||
|
class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
"""Implements dp_query for differentially private tree aggregation protocol.
|
||||||
|
|
||||||
|
Implements a central variant of the tree aggregation protocol from the paper
|
||||||
|
"'Is interaction necessary for distributed private learning?.' Adam Smith,
|
||||||
|
Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with
|
||||||
|
gaussian mechanism. The first step is to clip the clients' local updates (i.e.
|
||||||
|
a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it
|
||||||
|
does not exceed a prespecified upper bound. The second step is to construct
|
||||||
|
the tree on the clipped update. The third step is to add independent gaussian
|
||||||
|
noise to each node in the tree. The returned tree can support efficient and
|
||||||
|
accurate range queries with differential privacy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@attr.s(frozen=True)
|
@attr.s(frozen=True)
|
||||||
class GlobalState(object):
|
class GlobalState(object):
|
||||||
"""Class defining global state for TreeRangeSumQuery.
|
"""Class defining global state for `CentralTreeSumQuery`.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
arity: The branching factor of the tree (i.e. the number of children each
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
internal node has).
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
inner_query_state: The global state of the inner query.
|
|
||||||
"""
|
"""
|
||||||
arity = attr.ib()
|
l1_bound = attr.ib()
|
||||||
inner_query_state = attr.ib()
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
inner_query: dp_query.SumAggregationDPQuery,
|
stddev: float,
|
||||||
arity: int = 2):
|
arity: int = 2,
|
||||||
"""Initializes the `TreeRangeSumQuery`.
|
l1_bound: int = 10,
|
||||||
|
seed: Optional[int] = None):
|
||||||
|
"""Initializes the `CentralTreeSumQuery`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inner_query: The inner `DPQuery` that adds noise to the tree.
|
stddev: The stddev of the noise added to each internal node of the
|
||||||
arity: The branching factor of the tree (i.e. the number of children each
|
constructed tree.
|
||||||
internal node has). Defaults to 2.
|
arity: The branching factor of the tree.
|
||||||
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
|
||||||
|
test purpose.
|
||||||
"""
|
"""
|
||||||
self._inner_query = inner_query
|
self._stddev = stddev
|
||||||
self._arity = arity
|
self._arity = arity
|
||||||
|
self._l1_bound = l1_bound
|
||||||
if self._arity < 1:
|
self._seed = seed
|
||||||
raise ValueError(f'Invalid arity={arity} smaller than 2.')
|
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
return TreeRangeSumQuery.GlobalState(
|
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)
|
||||||
arity=self._arity,
|
|
||||||
inner_query_state=self._inner_query.initial_global_state())
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
return (global_state.arity,
|
return global_state.l1_bound
|
||||||
self._inner_query.derive_sample_params(
|
|
||||||
global_state.inner_query_state))
|
def preprocess_record(self, params, record):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||||
|
casted_record = tf.cast(record, tf.float32)
|
||||||
|
l1_norm = tf.norm(casted_record, ord=1)
|
||||||
|
|
||||||
|
l1_bound = tf.cast(params, tf.float32)
|
||||||
|
|
||||||
|
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
|
||||||
|
l1_bound,
|
||||||
|
use_norm=l1_norm)
|
||||||
|
|
||||||
|
return preprocessed_record[0]
|
||||||
|
|
||||||
|
def get_noised_result(self, sample_state, global_state):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_state: a frequency histogram.
|
||||||
|
global_state: hyper-parameters of the query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a `tf.RaggedTensor` representing the tree built on top of `sample_state`.
|
||||||
|
The jth node on the ith layer of the tree can be accessed by tree[i][j]
|
||||||
|
where tree is the returned value.
|
||||||
|
"""
|
||||||
|
add_noise = _get_add_noise(self._stddev, self._seed)
|
||||||
|
tree = _build_tree_from_leaf(sample_state, self._arity)
|
||||||
|
return tf.map_fn(add_noise, tree), global_state
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
"""Implements dp_query for differentially private tree aggregation protocol.
|
||||||
|
|
||||||
|
The difference from `CentralTreeSumQuery` is that the tree construction and
|
||||||
|
gaussian noise addition happen in `preprocess_records`. The difference only
|
||||||
|
takes effect when used together with
|
||||||
|
`tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class
|
||||||
|
should be treated as equal with `CentralTreeSumQuery`.
|
||||||
|
|
||||||
|
Implements a distributed version of the tree aggregation protocol from. "Is
|
||||||
|
interaction necessary for distributed private learning?." by replacing their
|
||||||
|
local randomizer with gaussian mechanism. The first step is to check the L1
|
||||||
|
norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes
|
||||||
|
of the tree) to make sure it does not exceed a prespecified upper bound. The
|
||||||
|
second step is to construct the tree. The third step is to add independent
|
||||||
|
gaussian noise to each node in the tree. The returned tree can support
|
||||||
|
efficient and accurate range queries with differential privacy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@attr.s(frozen=True)
|
||||||
|
class GlobalState(object):
|
||||||
|
"""Class defining global state for DistributedTreeSumQuery.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
stddev: The stddev of the noise added to each internal node in the
|
||||||
|
constructed tree.
|
||||||
|
arity: The branching factor of the tree (i.e. the number of children each
|
||||||
|
internal node has).
|
||||||
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
"""
|
||||||
|
stddev = attr.ib()
|
||||||
|
arity = attr.ib()
|
||||||
|
l1_bound = attr.ib()
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
stddev: float,
|
||||||
|
arity: int = 2,
|
||||||
|
l1_bound: int = 10,
|
||||||
|
seed: Optional[int] = None):
|
||||||
|
"""Initializes the `DistributedTreeSumQuery`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stddev: The stddev of the noise added to each node in the tree.
|
||||||
|
arity: The branching factor of the tree.
|
||||||
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
|
||||||
|
test purpose.
|
||||||
|
"""
|
||||||
|
self._stddev = stddev
|
||||||
|
self._arity = arity
|
||||||
|
self._l1_bound = l1_bound
|
||||||
|
self._seed = seed
|
||||||
|
|
||||||
|
def initial_global_state(self):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
|
return DistributedTreeSumQuery.GlobalState(
|
||||||
|
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
|
||||||
|
|
||||||
|
def derive_sample_params(self, global_state):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
|
return (global_state.stddev, global_state.arity, global_state.l1_bound)
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||||
|
|
||||||
This method builds the tree, flattens it and applies
|
This method clips the input record by L1 norm, constructs a tree on top of
|
||||||
`inner_query.preprocess_record` to the flattened tree.
|
it, and adds gaussian noise to each node of the tree for differential
|
||||||
|
privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function
|
||||||
|
flattens the `tf.RaggedTensor` before outputting it. This is useful when
|
||||||
|
used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does
|
||||||
|
not accept ragged output tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Hyper-parameters for preprocessing record.
|
params: hyper-parameters for preprocessing record, (stddev, aritry,
|
||||||
record: A histogram representing the leaf nodes of the tree.
|
l1_bound)
|
||||||
|
record: leaf nodes for the tree.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tf.Tensor` representing the flattened version of the preprocessed tree.
|
`tf.Tensor` representing the flattened version of the tree.
|
||||||
"""
|
"""
|
||||||
arity, inner_query_params = params
|
_, arity, l1_bound_ = params
|
||||||
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
|
l1_bound = tf.cast(l1_bound_, tf.float32)
|
||||||
preprocessed_record = self._inner_query.preprocess_record(
|
|
||||||
inner_query_params, preprocessed_record)
|
casted_record = tf.cast(record, tf.float32)
|
||||||
|
l1_norm = tf.norm(casted_record, ord=1)
|
||||||
|
|
||||||
|
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
|
||||||
|
l1_bound,
|
||||||
|
use_norm=l1_norm)
|
||||||
|
preprocessed_record = preprocessed_record[0]
|
||||||
|
|
||||||
|
add_noise = _get_add_noise(self._stddev, self._seed)
|
||||||
|
tree = _build_tree_from_leaf(preprocessed_record, arity)
|
||||||
|
noisy_tree = tf.map_fn(add_noise, tree)
|
||||||
|
|
||||||
# The following codes reshape the output vector so the output shape of can
|
# The following codes reshape the output vector so the output shape of can
|
||||||
# be statically inferred. This is useful when used with
|
# be statically inferred. This is useful when used with
|
||||||
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
||||||
# the output shape of this function statically and explicitly.
|
# the output shape of this function statically and explicitly.
|
||||||
preprocessed_record_shape = [
|
flat_noisy_tree = noisy_tree.flat_values
|
||||||
|
flat_tree_shape = [
|
||||||
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
||||||
1) // (self._arity - 1)
|
1) // (self._arity - 1)
|
||||||
]
|
]
|
||||||
return tf.reshape(preprocessed_record, preprocessed_record_shape)
|
return tf.reshape(flat_noisy_tree, flat_tree_shape)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||||
|
@ -525,11 +661,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
output by `preprocess_records.`
|
output by `preprocess_records.`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sample_state: A `tf.Tensor` for the flattened tree.
|
sample_state: `tf.Tensor` for the flattened tree.
|
||||||
global_state: The global state of the protocol.
|
global_state: hyper-parameters including noise multiplier, the branching
|
||||||
|
factor of the tree and the maximum records per user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tf.RaggedTensor` representing the tree.
|
a `tf.RaggedTensor` for the tree.
|
||||||
"""
|
"""
|
||||||
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
|
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
|
||||||
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
|
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
|
||||||
|
@ -545,60 +682,3 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
tree = tf.RaggedTensor.from_row_splits(
|
tree = tf.RaggedTensor.from_row_splits(
|
||||||
values=sample_state, row_splits=row_splits)
|
values=sample_state, row_splits=row_splits)
|
||||||
return tree, global_state
|
return tree, global_state
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build_central_gaussian_query(cls,
|
|
||||||
l2_norm_clip: float,
|
|
||||||
stddev: float,
|
|
||||||
arity: int = 2):
|
|
||||||
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
l2_norm_clip: Each record should be clipped so that it has L2 norm at most
|
|
||||||
`l2_norm_clip`.
|
|
||||||
stddev: Stddev of the central Gaussian noise.
|
|
||||||
arity: The branching factor of the tree (i.e. the number of children each
|
|
||||||
internal node has). Defaults to 2.
|
|
||||||
"""
|
|
||||||
if l2_norm_clip <= 0:
|
|
||||||
raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.')
|
|
||||||
|
|
||||||
if stddev < 0:
|
|
||||||
raise ValueError(f'`stddev` must be non-negative, got {stddev}.')
|
|
||||||
|
|
||||||
if arity < 2:
|
|
||||||
raise ValueError(f'`arity` must be at least 2, got {arity}.')
|
|
||||||
|
|
||||||
inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev)
|
|
||||||
|
|
||||||
return cls(arity=arity, inner_query=inner_query)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build_distributed_discrete_gaussian_query(cls,
|
|
||||||
l2_norm_bound: float,
|
|
||||||
local_stddev: float,
|
|
||||||
arity: int = 2):
|
|
||||||
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
l2_norm_bound: Each record should be clipped so that it has L2 norm at
|
|
||||||
most `l2_norm_bound`.
|
|
||||||
local_stddev: Scale/stddev of the local discrete Gaussian noise.
|
|
||||||
arity: The branching factor of the tree (i.e. the number of children each
|
|
||||||
internal node has). Defaults to 2.
|
|
||||||
"""
|
|
||||||
if l2_norm_bound <= 0:
|
|
||||||
raise ValueError(
|
|
||||||
f'`l2_clip_bound` must be positive, got {l2_norm_bound}.')
|
|
||||||
|
|
||||||
if local_stddev < 0:
|
|
||||||
raise ValueError(
|
|
||||||
f'`local_stddev` must be non-negative, got {local_stddev}.')
|
|
||||||
|
|
||||||
if arity < 2:
|
|
||||||
raise ValueError(f'`arity` must be at least 2, got {arity}.')
|
|
||||||
|
|
||||||
inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery(
|
|
||||||
l2_norm_bound, local_stddev)
|
|
||||||
|
|
||||||
return cls(arity=arity, inner_query=inner_query)
|
|
||||||
|
|
|
@ -423,115 +423,72 @@ class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertEqual(tree[layer][idx], expected_value)
|
self.assertEqual(tree[layer][idx], expected_value)
|
||||||
|
|
||||||
|
|
||||||
class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.product(
|
def test_initial_global_state_type(self):
|
||||||
inner_query=['central', 'distributed'],
|
|
||||||
params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)],
|
|
||||||
)
|
|
||||||
def test_raises_error(self, inner_query, params):
|
|
||||||
clip_norm, stddev, arity = params
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
if inner_query == 'central':
|
|
||||||
tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
|
||||||
clip_norm, stddev, arity)
|
|
||||||
elif inner_query == 'distributed':
|
|
||||||
tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
|
||||||
clip_norm, stddev, arity)
|
|
||||||
|
|
||||||
@parameterized.product(
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
|
||||||
inner_query=['central', 'distributed'],
|
|
||||||
clip_norm=[0.1, 1.0, 10.0],
|
|
||||||
stddev=[0.1, 1.0, 10.0])
|
|
||||||
def test_initial_global_state_type(self, inner_query, clip_norm, stddev):
|
|
||||||
|
|
||||||
if inner_query == 'central':
|
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
|
||||||
clip_norm, stddev)
|
|
||||||
elif inner_query == 'distributed':
|
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
|
||||||
clip_norm, stddev)
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
self.assertIsInstance(global_state,
|
self.assertIsInstance(
|
||||||
tree_aggregation_query.TreeRangeSumQuery.GlobalState)
|
global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState)
|
||||||
|
|
||||||
@parameterized.product(
|
def test_derive_sample_params(self):
|
||||||
inner_query=['central', 'distributed'],
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
|
||||||
clip_norm=[0.1, 1.0, 10.0],
|
|
||||||
stddev=[0.1, 1.0, 10.0],
|
|
||||||
arity=[2, 3, 4])
|
|
||||||
def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity):
|
|
||||||
if inner_query == 'central':
|
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
|
||||||
clip_norm, stddev, arity)
|
|
||||||
elif inner_query == 'distributed':
|
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
|
||||||
clip_norm, stddev, arity)
|
|
||||||
global_state = query.initial_global_state()
|
|
||||||
derived_arity, inner_query_state = query.derive_sample_params(global_state)
|
|
||||||
self.assertAllClose(derived_arity, arity)
|
|
||||||
if inner_query == 'central':
|
|
||||||
self.assertAllClose(inner_query_state, clip_norm)
|
|
||||||
elif inner_query == 'distributed':
|
|
||||||
self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm)
|
|
||||||
self.assertAllClose(inner_query_state.local_stddev, stddev)
|
|
||||||
|
|
||||||
@parameterized.product(
|
|
||||||
(dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]),
|
|
||||||
dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])),
|
|
||||||
inner_query=['central', 'distributed'],
|
|
||||||
)
|
|
||||||
def test_preprocess_record(self, inner_query, arity, expected_tree):
|
|
||||||
if inner_query == 'central':
|
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
|
||||||
10., 0., arity)
|
|
||||||
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
|
|
||||||
expected_tree = tf.cast(expected_tree, tf.float32)
|
|
||||||
elif inner_query == 'distributed':
|
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
|
||||||
10., 0., arity)
|
|
||||||
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, record)
|
self.assertAllClose(params, 10.)
|
||||||
self.assertAllClose(preprocessed_record, expected_tree)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
|
||||||
('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.],
|
||||||
|
dtype=tf.float32)),
|
||||||
)
|
)
|
||||||
def test_distributed_preprocess_record_with_noise(self, local_stddev, record,
|
def test_preprocess_record(self, arity, record):
|
||||||
expected_tree):
|
query = tree_aggregation_query.CentralTreeSumQuery(
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
stddev=NOISE_STD, arity=arity)
|
||||||
10., local_stddev)
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
|
|
||||||
preprocessed_record = query.preprocess_record(params, record)
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(preprocessed_record, record)
|
||||||
preprocessed_record, expected_tree, atol=10 * local_stddev)
|
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.named_parameters(
|
||||||
(dict(
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
arity=2,
|
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
|
||||||
expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])),
|
('binary_test_float', 2, tf.constant(
|
||||||
dict(
|
[10., 10., 0., 0.],
|
||||||
arity=3,
|
dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
|
||||||
expected_tree=tf.ragged.constant([[1], [1, 0, 0],
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
[1, 0, 0, 0, 0, 0, 0, 0, 0]]))),
|
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
|
||||||
inner_query=['central', 'distributed'],
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
|
||||||
)
|
)
|
||||||
def test_get_noised_result(self, inner_query, arity, expected_tree):
|
def test_preprocess_record_clipped(self, arity, record,
|
||||||
if inner_query == 'central':
|
expected_clipped_value):
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
query = tree_aggregation_query.CentralTreeSumQuery(
|
||||||
10., 0., arity)
|
stddev=NOISE_STD, arity=arity)
|
||||||
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
|
global_state = query.initial_global_state()
|
||||||
expected_tree = tf.cast(expected_tree, tf.float32)
|
params = query.derive_sample_params(global_state)
|
||||||
elif inner_query == 'distributed':
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
self.assertAllClose(preprocessed_record, expected_clipped_value)
|
||||||
10., 0., arity)
|
|
||||||
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, record)
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
@ -540,18 +497,167 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertAllClose(sample_state, expected_tree)
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
|
||||||
@parameterized.product(stddev=[0.1, 1.0, 10.0])
|
@parameterized.named_parameters(
|
||||||
def test_central_get_noised_result_with_noise(self, stddev):
|
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
10., stddev)
|
)
|
||||||
|
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
|
||||||
|
sample_state, _ = query.get_noised_result(preprocessed_record, global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(
|
||||||
|
sample_state.flat_values, expected_tree, atol=3 * stddev)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result_clipped(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
sample_state, global_state = query.get_noised_result(
|
sample_state, global_state = query.get_noised_result(
|
||||||
preprocessed_record, global_state)
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
|
|
||||||
|
|
||||||
|
class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_initial_global_state_type(self):
|
||||||
|
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
self.assertIsInstance(
|
||||||
|
global_state,
|
||||||
|
tree_aggregation_query.DistributedTreeSumQuery.GlobalState)
|
||||||
|
|
||||||
|
def test_derive_sample_params(self):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
stddev, arity, l1_bound = query.derive_sample_params(global_state)
|
||||||
|
self.assertAllClose(stddev, NOISE_STD)
|
||||||
|
self.assertAllClose(arity, 2)
|
||||||
|
self.assertAllClose(l1_bound, 10)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
|
||||||
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
|
||||||
|
])),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
|
||||||
|
])),
|
||||||
|
)
|
||||||
|
def test_preprocess_record(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
self.assertAllClose(preprocessed_record, expected_tree)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
|
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
|
)
|
||||||
|
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=stddev, seed=0)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
|
||||||
|
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
|
||||||
|
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
|
||||||
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant(
|
||||||
|
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
|
||||||
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant(
|
||||||
|
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
|
||||||
|
)
|
||||||
|
def test_preprocess_record_clipped(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
self.assertAllClose(preprocessed_record, expected_tree)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[1.], [1., 0., 0.],
|
||||||
|
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
sample_state, global_state = query.get_noised_result(
|
||||||
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
|
||||||
|
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
|
||||||
|
dtype=tf.float32),
|
||||||
|
tf.ragged.constant([[10.], [10., 0., 0.],
|
||||||
|
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
|
||||||
|
)
|
||||||
|
def test_get_noised_result_clipped(self, arity, record, expected_tree):
|
||||||
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=0., arity=arity)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
|
sample_state, global_state = query.get_noised_result(
|
||||||
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue