Automated rollback of commit d9a7596815

PiperOrigin-RevId: 391885401
This commit is contained in:
Michael Reneer 2021-08-19 17:56:36 -07:00 committed by A. Unique TensorFlower
parent d9a7596815
commit 0600fa26a2
2 changed files with 459 additions and 0 deletions

View file

@ -22,6 +22,7 @@ the leaf nodes of the tree arrive one by one as the time proceeds.
""" """
import distutils import distutils
import math import math
from typing import Optional
import attr import attr
import tensorflow as tf import tensorflow as tf
@ -731,3 +732,224 @@ def _get_add_noise(stddev, seed: int = None):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
return add_noise return add_noise
class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for differentially private tree aggregation protocol.
Implements a central variant of the tree aggregation protocol from the paper
"'Is interaction necessary for distributed private learning?.' Adam Smith,
Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with
gaussian mechanism. The first step is to clip the clients' local updates (i.e.
a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it
does not exceed a prespecified upper bound. The second step is to construct
the tree on the clipped update. The third step is to add independent gaussian
noise to each node in the tree. The returned tree can support efficient and
accurate range queries with differential privacy.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for `CentralTreeSumQuery`.
Attributes:
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
l1_bound = attr.ib()
def __init__(self,
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `CentralTreeSumQuery`.
Args:
stddev: The stddev of the noise added to each internal node of the
constructed tree.
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
self._seed = seed
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return global_state.l1_bound
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
casted_record = tf.cast(record, tf.float32)
l1_norm = tf.norm(casted_record, ord=1)
l1_bound = tf.cast(params, tf.float32)
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
l1_bound,
use_norm=l1_norm)
return preprocessed_record[0]
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
Args:
sample_state: a frequency histogram.
global_state: hyper-parameters of the query.
Returns:
a `tf.RaggedTensor` representing the tree built on top of `sample_state`.
The jth node on the ith layer of the tree can be accessed by tree[i][j]
where tree is the returned value.
"""
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(sample_state, self._arity)
return tf.map_fn(add_noise, tree), global_state
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for differentially private tree aggregation protocol.
The difference from `CentralTreeSumQuery` is that the tree construction and
gaussian noise addition happen in `preprocess_records`. The difference only
takes effect when used together with
`tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class
should be treated as equal with `CentralTreeSumQuery`.
Implements a distributed version of the tree aggregation protocol from. "Is
interaction necessary for distributed private learning?." by replacing their
local randomizer with gaussian mechanism. The first step is to check the L1
norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes
of the tree) to make sure it does not exceed a prespecified upper bound. The
second step is to construct the tree. The third step is to add independent
gaussian noise to each node in the tree. The returned tree can support
efficient and accurate range queries with differential privacy.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for DistributedTreeSumQuery.
Attributes:
stddev: The stddev of the noise added to each internal node in the
constructed tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
stddev = attr.ib()
arity = attr.ib()
l1_bound = attr.ib()
def __init__(self,
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `DistributedTreeSumQuery`.
Args:
stddev: The stddev of the noise added to each node in the tree.
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
self._seed = seed
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return DistributedTreeSumQuery.GlobalState(
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.stddev, global_state.arity, global_state.l1_bound)
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
This method clips the input record by L1 norm, constructs a tree on top of
it, and adds gaussian noise to each node of the tree for differential
privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function
flattens the `tf.RaggedTensor` before outputting it. This is useful when
used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does
not accept ragged output tensor.
Args:
params: hyper-parameters for preprocessing record, (stddev, aritry,
l1_bound)
record: leaf nodes for the tree.
Returns:
`tf.Tensor` representing the flattened version of the tree.
"""
_, arity, l1_bound_ = params
l1_bound = tf.cast(l1_bound_, tf.float32)
casted_record = tf.cast(record, tf.float32)
l1_norm = tf.norm(casted_record, ord=1)
preprocessed_record, _ = tf.clip_by_global_norm([casted_record],
l1_bound,
use_norm=l1_norm)
preprocessed_record = preprocessed_record[0]
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(preprocessed_record, arity)
noisy_tree = tf.map_fn(add_noise, tree)
# The following codes reshape the output vector so the output shape of can
# be statically inferred. This is useful when used with
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
# the output shape of this function statically and explicitly.
flat_noisy_tree = noisy_tree.flat_values
flat_tree_shape = [
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
1) // (self._arity - 1)
]
return tf.reshape(flat_noisy_tree, flat_tree_shape)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
This function re-constructs the `tf.RaggedTensor` from the flattened tree
output by `preprocess_records.`
Args:
sample_state: `tf.Tensor` for the flattened tree.
global_state: hyper-parameters including noise multiplier, the branching
factor of the tree and the maximum records per user.
Returns:
a `tf.RaggedTensor` for the tree.
"""
# The [0] is needed because of how tf.RaggedTensor.from_two_splits works.
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
# row_splits=[0, 4, 4, 7, 8, 8]))
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
row_splits = [0] + [
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, global_state

View file

@ -630,5 +630,242 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev) sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_initial_global_state_type(self):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
self.assertIsInstance(
global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState)
def test_derive_sample_params(self):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
self.assertAllClose(params, 10.)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.],
dtype=tf.float32)),
)
def test_preprocess_record(self, arity, record):
query = tree_aggregation_query.CentralTreeSumQuery(
stddev=NOISE_STD, arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, record)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
('binary_test_float', 2, tf.constant(
[10., 10., 0., 0.],
dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.constant([5, 5, 0, 0], dtype=tf.int32)),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.constant([5., 5., 0., 0.], dtype=tf.float32)),
)
def test_preprocess_record_clipped(self, arity, record,
expected_clipped_value):
query = tree_aggregation_query.CentralTreeSumQuery(
stddev=NOISE_STD, arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_clipped_value)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result(self, arity, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.named_parameters(
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, _ = query.get_noised_result(preprocessed_record, global_state)
self.assertAllClose(
sample_state.flat_values, expected_tree, atol=3 * stddev)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_initial_global_state_type(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
self.assertIsInstance(
global_state,
tree_aggregation_query.DistributedTreeSumQuery.GlobalState)
def test_derive_sample_params(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
stddev, arity, l1_bound = query.derive_sample_params(global_state)
self.assertAllClose(stddev, NOISE_STD)
self.assertAllClose(arity, 2)
self.assertAllClose(l1_bound, 10)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
])),
)
def test_preprocess_record(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant(
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant(
[10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])),
)
def test_preprocess_record_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])),
('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32),
tf.ragged.constant([[1.], [1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('binary_test_float', 2, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])),
('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
('ternary_test_float', 3, tf.constant([10., 10., 0., 0.],
dtype=tf.float32),
tf.ragged.constant([[10.], [10., 0., 0.],
[5., 5., 0., 0., 0., 0., 0., 0., 0.]])),
)
def test_get_noised_result_clipped(self, arity, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=0., arity=arity)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()