(1) Merge CentralTreeSumQuery and DistributedTreeSumQuery into a single DPQuery (TreeRangeSumQuery) to modularize the code. The new query takes an inner_query argument; depending on the behavior of the inner query, it provides either central DP or distributed DP. (A usage sketch follows the file summary below.)

(2) Remove the hard-coded L1 clipping and replace it with norm-bound checking in the inner query. This design lets us plug in any clipping factory outside the DPQuery.

PiperOrigin-RevId: 387398741
A. Unique TensorFlower 2021-07-28 11:39:48 -07:00
parent eef5810d94
commit 2672559471
2 changed files with 301 additions and 6 deletions
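
A quick usage sketch of the merged query (illustrative only, based on the factory methods and tests added in this commit; not part of the diff):

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query

# Central DP: the Gaussian inner query clips each record; noise is added once
# on the server in get_noised_result.
central_query = (
    tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
        l2_norm_clip=10., stddev=1., arity=2))

# Distributed DP: each client adds local discrete Gaussian noise during
# preprocess_record, before aggregation.
distributed_query = (
    tree_aggregation_query.TreeRangeSumQuery
    .build_distributed_discrete_gaussian_query(
        l2_norm_bound=10., local_stddev=1., arity=2))

# Both variants follow the same DPQuery flow; each record is a histogram of
# leaf values.
global_state = central_query.initial_global_state()
params = central_query.derive_sample_params(global_state)
record = tf.constant([1., 0., 0., 0.])
flat_tree = central_query.preprocess_record(params, record)
noised_tree, global_state = central_query.get_noised_result(flat_tree, global_state)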

tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py

@@ -15,13 +15,10 @@
 `TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
 online observation queries relying on `tree_aggregation`. 'Online' means that
-the leaf nodes of the tree arrive one by one as the time proceeds. The leaves
-are vector records as defined in `dp_query.DPQuery`.
+the leaf nodes of the tree arrive one by one as the time proceeds.
 
-`CentralTreeSumQuery` and `DistributedTreeSumQuery` are `DPQuery`s for
-central/distributed offline tree aggregation protocol. 'Offline' means all the
-leaf nodes are ready before the protocol starts. Each record, different from
-what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
+`TreeRangeSumQuery` is a `DPQuery` for the offline tree aggregation protocol.
+'Offline' means all the leaf nodes are ready before the protocol starts.
 """
 
 import distutils
 import math
@@ -29,7 +26,9 @@ from typing import Optional
 import attr
 import tensorflow as tf
+from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
 from tensorflow_privacy.privacy.dp_query import dp_query
+from tensorflow_privacy.privacy.dp_query import gaussian_query
 from tensorflow_privacy.privacy.dp_query import tree_aggregation
@@ -442,6 +441,171 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
  return tree
class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for accurate range queries using tree aggregation.
Implements a variant of the tree aggregation protocol from. "Is interaction
necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta,
Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to
the tree for differential privacy. Any range query can be decomposed into the
sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram.
Improves efficiency and reduces noise scale.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for TreeRangeSumQuery.
Attributes:
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
inner_query_state: The global state of the inner query.
"""
arity = attr.ib()
inner_query_state = attr.ib()
def __init__(self,
inner_query: dp_query.SumAggregationDPQuery,
arity: int = 2):
"""Initializes the `TreeRangeSumQuery`.
Args:
inner_query: The inner `DPQuery` that adds noise to the tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
self._inner_query = inner_query
self._arity = arity
    if self._arity < 2:
raise ValueError(f'Invalid arity={arity} smaller than 2.')
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return TreeRangeSumQuery.GlobalState(
arity=self._arity,
inner_query_state=self._inner_query.initial_global_state())
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.arity,
self._inner_query.derive_sample_params(
global_state.inner_query_state))
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
This method builds the tree, flattens it and applies
`inner_query.preprocess_record` to the flattened tree.
Args:
params: Hyper-parameters for preprocessing record.
record: A histogram representing the leaf nodes of the tree.
Returns:
A `tf.Tensor` representing the flattened version of the preprocessed tree.
"""
arity, inner_query_params = params
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
preprocessed_record = self._inner_query.preprocess_record(
inner_query_params, preprocessed_record)
    # The following code reshapes the output vector so that the output shape
    # can be statically inferred. This is useful when used with
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
# the output shape of this function statically and explicitly.
preprocessed_record_shape = [
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
1) // (self._arity - 1)
]
return tf.reshape(preprocessed_record, preprocessed_record_shape)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
    This function reconstructs the `tf.RaggedTensor` from the flattened tree
    output by `preprocess_record`.
Args:
sample_state: A `tf.Tensor` for the flattened tree.
global_state: The global state of the protocol.
Returns:
A `tf.RaggedTensor` representing the tree.
"""
    # The [0] is needed because of how tf.RaggedTensor.from_row_splits works.
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
# row_splits=[0, 4, 4, 7, 8, 8]))
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
row_splits = [0] + [
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, global_state
@classmethod
def build_central_gaussian_query(cls,
l2_norm_clip: float,
stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_clip: Each record should be clipped so that it has L2 norm at most
`l2_norm_clip`.
stddev: Stddev of the central Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_clip <= 0:
raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.')
if stddev < 0:
raise ValueError(f'`stddev` must be non-negative, got {stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev)
return cls(arity=arity, inner_query=inner_query)
@classmethod
def build_distributed_discrete_gaussian_query(cls,
l2_norm_bound: float,
local_stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_bound: Each record should be clipped so that it has L2 norm at
most `l2_norm_bound`.
local_stddev: Scale/stddev of the local discrete Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_bound <= 0:
raise ValueError(
          f'`l2_norm_bound` must be positive, got {l2_norm_bound}.')
if local_stddev < 0:
raise ValueError(
f'`local_stddev` must be non-negative, got {local_stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery(
l2_norm_bound, local_stddev)
return cls(arity=arity, inner_query=inner_query)
def _get_add_noise(stddev, seed: int = None):
  """Utility function to decide which `add_noise` to use according to tf version."""
  if distutils.version.LooseVersion(

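Aside (not part of the diff): a minimal sketch of the shape and row-splits arithmetic used in preprocess_record and get_noised_result above, assuming arity 2 and a 4-leaf histogram; the numbers match the expected trees in the tests below.

import math

arity, num_leaves = 2, 4

# Flattened tree length: (arity**(depth + 1) - 1) / (arity - 1),
# where depth = ceil(log_arity(num_leaves)).
flat_len = (arity**(math.ceil(math.log(num_leaves, arity)) + 1) - 1) // (arity - 1)
# flat_len == 7, so the record [1, 0, 0, 0] becomes the 7-node tree
# [1, 1, 0, 1, 0, 0, 0].

# Row splits are cumulative node counts per tree level.
row_splits = [0] + [
    (arity**(x + 1) - 1) // (arity - 1)
    for x in range(math.floor(math.log(flat_len, arity)) + 1)
]
# row_splits == [0, 1, 3, 7]; tf.RaggedTensor.from_row_splits(values, row_splits)
# then rebuilds the tree as [[1], [1, 0], [1, 0, 0, 0]].
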
tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py

@@ -423,6 +423,137 @@ class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
        self.assertEqual(tree[layer][idx], expected_value)
class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
inner_query=['central', 'distributed'],
params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)],
)
def test_raises_error(self, inner_query, params):
clip_norm, stddev, arity = params
with self.assertRaises(ValueError):
if inner_query == 'central':
tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0])
def test_initial_global_state_type(self, inner_query, clip_norm, stddev):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev)
global_state = query.initial_global_state()
self.assertIsInstance(global_state,
tree_aggregation_query.TreeRangeSumQuery.GlobalState)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0],
arity=[2, 3, 4])
def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
global_state = query.initial_global_state()
derived_arity, inner_query_state = query.derive_sample_params(global_state)
self.assertAllClose(derived_arity, arity)
if inner_query == 'central':
self.assertAllClose(inner_query_state, clip_norm)
elif inner_query == 'distributed':
self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm)
self.assertAllClose(inner_query_state.local_stddev, stddev)
@parameterized.product(
(dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]),
dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])),
inner_query=['central', 'distributed'],
)
def test_preprocess_record(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
      ('stddev_4', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
)
def test_distributed_preprocess_record_with_noise(self, local_stddev, record,
expected_tree):
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., local_stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(
preprocessed_record, expected_tree, atol=10 * local_stddev)
@parameterized.product(
(dict(
arity=2,
expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])),
dict(
arity=3,
expected_tree=tf.ragged.constant([[1], [1, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0]]))),
inner_query=['central', 'distributed'],
)
def test_get_noised_result(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_aggregation_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.product(stddev=[0.1, 1.0, 10.0])
def test_central_get_noised_result_with_noise(self, stddev):
query = tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
10., stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):

  def test_initial_global_state_type(self):