forked from 626_privacy/tensorflow_privacy
(1) Merge CentralTreeSumQuery
and DistributedTreeSumQuery
into one DPQuery to modularize things. The new query takes in an inner_query
argument. Depending on the behavior of inner query, the query will follow central DP or distributed DP.
(2) Remove the hard-coded L1 clipping and replace with norm bound checking in the inner query. This design allows us to use whatever clipping factory we want outside the DPQuery. PiperOrigin-RevId: 387398741
This commit is contained in:
parent
eef5810d94
commit
2672559471
2 changed files with 301 additions and 6 deletions
|
@ -15,13 +15,10 @@
|
|||
|
||||
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
|
||||
online observation queries relying on `tree_aggregation`. 'Online' means that
|
||||
the leaf nodes of the tree arrive one by one as the time proceeds. The leaves
|
||||
are vector records as defined in `dp_query.DPQuery`.
|
||||
the leaf nodes of the tree arrive one by one as the time proceeds.
|
||||
|
||||
`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for
|
||||
central/distributed offline tree aggregation protocol. 'Offline' means all the
|
||||
leaf nodes are ready before the protocol starts. Each record, different from
|
||||
what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
|
||||
`TreeRangeSumQuery` is a `DPQuery`s for offline tree aggregation protocol.
|
||||
'Offline' means all the leaf nodes are ready before the protocol starts.
|
||||
"""
|
||||
import distutils
|
||||
import math
|
||||
|
@ -29,7 +26,9 @@ from typing import Optional
|
|||
|
||||
import attr
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
|
||||
|
||||
|
@ -442,6 +441,171 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
|||
return tree
|
||||
|
||||
|
||||
class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements dp_query for accurate range queries using tree aggregation.
|
||||
|
||||
Implements a variant of the tree aggregation protocol from. "Is interaction
|
||||
necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta,
|
||||
Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to
|
||||
the tree for differential privacy. Any range query can be decomposed into the
|
||||
sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram.
|
||||
Improves efficiency and reduces noise scale.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
class GlobalState(object):
|
||||
"""Class defining global state for TreeRangeSumQuery.
|
||||
|
||||
Attributes:
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has).
|
||||
inner_query_state: The global state of the inner query.
|
||||
"""
|
||||
arity = attr.ib()
|
||||
inner_query_state = attr.ib()
|
||||
|
||||
def __init__(self,
|
||||
inner_query: dp_query.SumAggregationDPQuery,
|
||||
arity: int = 2):
|
||||
"""Initializes the `TreeRangeSumQuery`.
|
||||
|
||||
Args:
|
||||
inner_query: The inner `DPQuery` that adds noise to the tree.
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has). Defaults to 2.
|
||||
"""
|
||||
self._inner_query = inner_query
|
||||
self._arity = arity
|
||||
|
||||
if self._arity < 1:
|
||||
raise ValueError(f'Invalid arity={arity} smaller than 2.')
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return TreeRangeSumQuery.GlobalState(
|
||||
arity=self._arity,
|
||||
inner_query_state=self._inner_query.initial_global_state())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return (global_state.arity,
|
||||
self._inner_query.derive_sample_params(
|
||||
global_state.inner_query_state))
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||
|
||||
This method builds the tree, flattens it and applies
|
||||
`inner_query.preprocess_record` to the flattened tree.
|
||||
|
||||
Args:
|
||||
params: Hyper-parameters for preprocessing record.
|
||||
record: A histogram representing the leaf nodes of the tree.
|
||||
|
||||
Returns:
|
||||
A `tf.Tensor` representing the flattened version of the preprocessed tree.
|
||||
"""
|
||||
arity, inner_query_params = params
|
||||
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
|
||||
preprocessed_record = self._inner_query.preprocess_record(
|
||||
inner_query_params, preprocessed_record)
|
||||
|
||||
# The following codes reshape the output vector so the output shape of can
|
||||
# be statically inferred. This is useful when used with
|
||||
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
||||
# the output shape of this function statically and explicitly.
|
||||
preprocessed_record_shape = [
|
||||
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
||||
1) // (self._arity - 1)
|
||||
]
|
||||
return tf.reshape(preprocessed_record, preprocessed_record_shape)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||
|
||||
This function re-constructs the `tf.RaggedTensor` from the flattened tree
|
||||
output by `preprocess_records.`
|
||||
|
||||
Args:
|
||||
sample_state: A `tf.Tensor` for the flattened tree.
|
||||
global_state: The global state of the protocol.
|
||||
|
||||
Returns:
|
||||
A `tf.RaggedTensor` representing the tree.
|
||||
"""
|
||||
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
|
||||
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
|
||||
# row_splits=[0, 4, 4, 7, 8, 8]))
|
||||
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
|
||||
# This part is not written in tensorflow and will be executed on the server
|
||||
# side instead of the client side if used with
|
||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||
row_splits = [0] + [
|
||||
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
||||
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
|
||||
]
|
||||
tree = tf.RaggedTensor.from_row_splits(
|
||||
values=sample_state, row_splits=row_splits)
|
||||
return tree, global_state
|
||||
|
||||
@classmethod
|
||||
def build_central_gaussian_query(cls,
|
||||
l2_norm_clip: float,
|
||||
stddev: float,
|
||||
arity: int = 2):
|
||||
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
|
||||
|
||||
Args:
|
||||
l2_norm_clip: Each record should be clipped so that it has L2 norm at most
|
||||
`l2_norm_clip`.
|
||||
stddev: Stddev of the central Gaussian noise.
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has). Defaults to 2.
|
||||
"""
|
||||
if l2_norm_clip <= 0:
|
||||
raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.')
|
||||
|
||||
if stddev < 0:
|
||||
raise ValueError(f'`stddev` must be non-negative, got {stddev}.')
|
||||
|
||||
if arity < 2:
|
||||
raise ValueError(f'`arity` must be at least 2, got {arity}.')
|
||||
|
||||
inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev)
|
||||
|
||||
return cls(arity=arity, inner_query=inner_query)
|
||||
|
||||
@classmethod
|
||||
def build_distributed_discrete_gaussian_query(cls,
|
||||
l2_norm_bound: float,
|
||||
local_stddev: float,
|
||||
arity: int = 2):
|
||||
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
|
||||
|
||||
Args:
|
||||
l2_norm_bound: Each record should be clipped so that it has L2 norm at
|
||||
most `l2_norm_bound`.
|
||||
local_stddev: Scale/stddev of the local discrete Gaussian noise.
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has). Defaults to 2.
|
||||
"""
|
||||
if l2_norm_bound <= 0:
|
||||
raise ValueError(
|
||||
f'`l2_clip_bound` must be positive, got {l2_norm_bound}.')
|
||||
|
||||
if local_stddev < 0:
|
||||
raise ValueError(
|
||||
f'`local_stddev` must be non-negative, got {local_stddev}.')
|
||||
|
||||
if arity < 2:
|
||||
raise ValueError(f'`arity` must be at least 2, got {arity}.')
|
||||
|
||||
inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery(
|
||||
l2_norm_bound, local_stddev)
|
||||
|
||||
return cls(arity=arity, inner_query=inner_query)
|
||||
|
||||
|
||||
def _get_add_noise(stddev, seed: int = None):
|
||||
"""Utility function to decide which `add_noise` to use according to tf version."""
|
||||
if distutils.version.LooseVersion(
|
||||
|
|
|
@ -423,6 +423,137 @@ class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertEqual(tree[layer][idx], expected_value)
|
||||
|
||||
|
||||
class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
inner_query=['central', 'distributed'],
|
||||
params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)],
|
||||
)
|
||||
def test_raises_error(self, inner_query, params):
|
||||
clip_norm, stddev, arity = params
|
||||
with self.assertRaises(ValueError):
|
||||
if inner_query == 'central':
|
||||
tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
elif inner_query == 'distributed':
|
||||
tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
|
||||
@parameterized.product(
|
||||
inner_query=['central', 'distributed'],
|
||||
clip_norm=[0.1, 1.0, 10.0],
|
||||
stddev=[0.1, 1.0, 10.0])
|
||||
def test_initial_global_state_type(self, inner_query, clip_norm, stddev):
|
||||
|
||||
if inner_query == 'central':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
clip_norm, stddev)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
clip_norm, stddev)
|
||||
global_state = query.initial_global_state()
|
||||
self.assertIsInstance(global_state,
|
||||
tree_aggregation_query.TreeRangeSumQuery.GlobalState)
|
||||
|
||||
@parameterized.product(
|
||||
inner_query=['central', 'distributed'],
|
||||
clip_norm=[0.1, 1.0, 10.0],
|
||||
stddev=[0.1, 1.0, 10.0],
|
||||
arity=[2, 3, 4])
|
||||
def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity):
|
||||
if inner_query == 'central':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
clip_norm, stddev, arity)
|
||||
global_state = query.initial_global_state()
|
||||
derived_arity, inner_query_state = query.derive_sample_params(global_state)
|
||||
self.assertAllClose(derived_arity, arity)
|
||||
if inner_query == 'central':
|
||||
self.assertAllClose(inner_query_state, clip_norm)
|
||||
elif inner_query == 'distributed':
|
||||
self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm)
|
||||
self.assertAllClose(inner_query_state.local_stddev, stddev)
|
||||
|
||||
@parameterized.product(
|
||||
(dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]),
|
||||
dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
inner_query=['central', 'distributed'],
|
||||
)
|
||||
def test_preprocess_record(self, inner_query, arity, expected_tree):
|
||||
if inner_query == 'central':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
|
||||
expected_tree = tf.cast(expected_tree, tf.float32)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
self.assertAllClose(preprocessed_record, expected_tree)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
|
||||
('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
|
||||
)
|
||||
def test_distributed_preprocess_record_with_noise(self, local_stddev, record,
|
||||
expected_tree):
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
10., local_stddev)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
|
||||
self.assertAllClose(
|
||||
preprocessed_record, expected_tree, atol=10 * local_stddev)
|
||||
|
||||
@parameterized.product(
|
||||
(dict(
|
||||
arity=2,
|
||||
expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])),
|
||||
dict(
|
||||
arity=3,
|
||||
expected_tree=tf.ragged.constant([[1], [1, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0, 0, 0, 0]]))),
|
||||
inner_query=['central', 'distributed'],
|
||||
)
|
||||
def test_get_noised_result(self, inner_query, arity, expected_tree):
|
||||
if inner_query == 'central':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
|
||||
expected_tree = tf.cast(expected_tree, tf.float32)
|
||||
elif inner_query == 'distributed':
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
10., 0., arity)
|
||||
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(sample_state, expected_tree)
|
||||
|
||||
@parameterized.product(stddev=[0.1, 1.0, 10.0])
|
||||
def test_central_get_noised_result_with_noise(self, stddev):
|
||||
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
|
||||
10., stddev)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(
|
||||
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
|
||||
|
||||
|
||||
class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_initial_global_state_type(self):
|
||||
|
|
Loading…
Reference in a new issue