Move TreeRangeSumQuery to its own module. This is the first step, will remove the function in the old module after a TFP release.

PiperOrigin-RevId: 392776774
This commit is contained in:
Zheng Xu 2021-08-24 16:49:58 -07:00 committed by A. Unique TensorFlower
parent 477b5b2899
commit 853b18929d
4 changed files with 472 additions and 5 deletions

View file

@ -56,6 +56,7 @@ else:
from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery
# Estimators # Estimators
from tensorflow_privacy.privacy.estimators.dnn import DNNClassifier from tensorflow_privacy.privacy.estimators.dnn import DNNClassifier

View file

@ -15,10 +15,9 @@
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual `TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
online observation queries relying on `tree_aggregation`. 'Online' means that online observation queries relying on `tree_aggregation`. 'Online' means that
the leaf nodes of the tree arrive one by one as the time proceeds. the leaf nodes of the tree arrive one by one as the time proceeds. The core
logic of tree aggregation is implemented in `tree_aggregation.TreeAggregator`
`TreeRangeSumQuery` is a `DPQuery`s for offline tree aggregation protocol. and `tree_aggregation.EfficientTreeAggregator`.
'Offline' means all the leaf nodes are ready before the protocol starts.
""" """
import distutils import distutils
import math import math
@ -31,7 +30,7 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation
# TODO(b/192464750): define `RestartQuery` and move `RestartIndicator` to be # TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
# in the same module. # in the same module.
@ -477,6 +476,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
restart_indicator=restart_indicator) restart_indicator=restart_indicator)
# TODO(b/197596864): Remove `TreeRangeSumQuery` from this file after the next
# TFP release
@tf.function @tf.function
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor: def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes. """A function constructs a complete tree given all the leaf nodes.

View file

@ -0,0 +1,281 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""`DPQuery`s for offline differentially private tree aggregation protocols.
'Offline' means all the leaf nodes are ready before the protocol starts.
"""
import distutils
import math
import attr
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
@tf.function
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes.
The function takes a 1-D array representing the leaf nodes of a tree and the
tree's arity, and constructs a complete tree by recursively summing the
adjacent children to get the parent until reaching the root node. Because we
assume a complete tree, if the number of leaf nodes does not divide arity, the
leaf nodes will be padded with zeros.
Args:
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
arity: A `int` for the branching factor of the tree, i.e. the number of
children for each internal node.
Returns:
`tf.RaggedTensor` representing the tree. For example, if
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
`tree[layer][index]` can be used to access the node indexed by (layer,
index) in the tree,
"""
def pad_zero(leaf_nodes, size):
paddings = [[0, size - len(leaf_nodes)]]
return tf.pad(leaf_nodes, paddings)
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
num_layers = tf.math.ceil(
tf.math.log(leaf_nodes_size) /
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
leaf_nodes = pad_zero(
leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1))
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
# The following `tf.while_loop` constructs the tree from bottom up by
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
# support auto-translation from python loop to tf loop when loop variables
# contain a `RaggedTensor` whose shape changes across iterations.
idx = tf.identity(num_layers)
loop_cond = lambda i, h: tf.less_equal(2.0, i)
def _loop_body(i, h):
return [
tf.add(i, -1.0),
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
]
_, tree = tf.while_loop(
loop_cond,
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
shape_invariants=[
idx.get_shape(),
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
])
return tree
class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for accurate range queries using tree aggregation.
Implements a variant of the tree aggregation protocol from. "Is interaction
necessary for distributed private learning?. Adam Smith, Abhradeep Thakurta,
Jalaj Upadhyay." Builds a tree on top of the input record and adds noise to
the tree for differential privacy. Any range query can be decomposed into the
sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram.
Improves efficiency and reduces noise scale.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for TreeRangeSumQuery.
Attributes:
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
inner_query_state: The global state of the inner query.
"""
arity = attr.ib()
inner_query_state = attr.ib()
def __init__(self,
inner_query: dp_query.SumAggregationDPQuery,
arity: int = 2):
"""Initializes the `TreeRangeSumQuery`.
Args:
inner_query: The inner `DPQuery` that adds noise to the tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
self._inner_query = inner_query
self._arity = arity
if self._arity < 1:
raise ValueError(f'Invalid arity={arity} smaller than 2.')
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return TreeRangeSumQuery.GlobalState(
arity=self._arity,
inner_query_state=self._inner_query.initial_global_state())
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.arity,
self._inner_query.derive_sample_params(
global_state.inner_query_state))
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
This method builds the tree, flattens it and applies
`inner_query.preprocess_record` to the flattened tree.
Args:
params: Hyper-parameters for preprocessing record.
record: A histogram representing the leaf nodes of the tree.
Returns:
A `tf.Tensor` representing the flattened version of the preprocessed tree.
"""
arity, inner_query_params = params
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
# The following codes reshape the output vector so the output shape of can
# be statically inferred. This is useful when used with
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
# the output shape of this function statically and explicitly.
preprocessed_record_shape = [
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
1) // (self._arity - 1)
]
preprocessed_record = tf.reshape(preprocessed_record,
preprocessed_record_shape)
preprocessed_record = self._inner_query.preprocess_record(
inner_query_params, preprocessed_record)
return preprocessed_record
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
This function re-constructs the `tf.RaggedTensor` from the flattened tree
output by `preprocess_records.`
Args:
sample_state: A `tf.Tensor` for the flattened tree.
global_state: The global state of the protocol.
Returns:
A `tf.RaggedTensor` representing the tree.
"""
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
# row_splits=[0, 4, 4, 7, 8, 8]))
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
sample_state, inner_query_state = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity, inner_query_state=inner_query_state)
row_splits = [0] + [
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, new_global_state
@classmethod
def build_central_gaussian_query(cls,
l2_norm_clip: float,
stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_clip: Each record should be clipped so that it has L2 norm at most
`l2_norm_clip`.
stddev: Stddev of the central Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_clip <= 0:
raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.')
if stddev < 0:
raise ValueError(f'`stddev` must be non-negative, got {stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev)
return cls(arity=arity, inner_query=inner_query)
@classmethod
def build_distributed_discrete_gaussian_query(cls,
l2_norm_bound: float,
local_stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_bound: Each record should be clipped so that it has L2 norm at
most `l2_norm_bound`.
local_stddev: Scale/stddev of the local discrete Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_bound <= 0:
raise ValueError(
f'`l2_clip_bound` must be positive, got {l2_norm_bound}.')
if local_stddev < 0:
raise ValueError(
f'`local_stddev` must be non-negative, got {local_stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery(
l2_norm_bound, local_stddev)
return cls(arity=arity, inner_query=inner_query)
def _get_add_noise(stddev, seed: int = None):
"""Utility function to decide which `add_noise` to use according to tf version."""
if distutils.version.LooseVersion(
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
# The seed should be only used for testing purpose.
if seed is not None:
tf.random.set_seed(seed)
def add_noise(v):
return v + tf.random.normal(
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
else:
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
return add_noise

View file

@ -0,0 +1,182 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `tree_range_query`."""
import math
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import tree_range_query
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
leaf_nodes_size=[1, 2, 3, 4, 5],
arity=[2, 3],
dtype=[tf.int32, tf.float32],
)
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
"""Test whether `_build_tree_from_leaf` will output the correct tree."""
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
tree = tree_range_query._build_tree_from_leaf(leaf_nodes, arity)
self.assertEqual(depth, tree.shape[0])
for layer in range(depth):
reverse_depth = tree.shape[0] - layer - 1
span_size = arity**reverse_depth
for idx in range(arity**layer):
left = idx * span_size
right = (idx + 1) * span_size
expected_value = sum(leaf_nodes[left:right])
self.assertEqual(tree[layer][idx], expected_value)
class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
inner_query=['central', 'distributed'],
params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)],
)
def test_raises_error(self, inner_query, params):
clip_norm, stddev, arity = params
with self.assertRaises(ValueError):
if inner_query == 'central':
tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0])
def test_initial_global_state_type(self, inner_query, clip_norm, stddev):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev)
global_state = query.initial_global_state()
self.assertIsInstance(global_state,
tree_range_query.TreeRangeSumQuery.GlobalState)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0],
arity=[2, 3, 4])
def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
global_state = query.initial_global_state()
derived_arity, inner_query_state = query.derive_sample_params(global_state)
self.assertAllClose(derived_arity, arity)
if inner_query == 'central':
self.assertAllClose(inner_query_state, clip_norm)
elif inner_query == 'distributed':
self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm)
self.assertAllClose(inner_query_state.local_stddev, stddev)
@parameterized.product(
(dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]),
dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])),
inner_query=['central', 'distributed'],
)
def test_preprocess_record(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
)
def test_distributed_preprocess_record_with_noise(self, local_stddev, record,
expected_tree):
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., local_stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(
preprocessed_record, expected_tree, atol=10 * local_stddev)
@parameterized.product(
(dict(
arity=2,
expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])),
dict(
arity=3,
expected_tree=tf.ragged.constant([[1], [1, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0]]))),
inner_query=['central', 'distributed'],
)
def test_get_noised_result(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.product(stddev=[0.1, 1.0, 10.0])
def test_central_get_noised_result_with_noise(self, stddev):
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
10., stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
if __name__ == '__main__':
tf.test.main()