diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 7ef73a1..4e19a49 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -22,7 +22,6 @@ the leaf nodes of the tree arrive one by one as the time proceeds. """ import distutils import math -from typing import Optional import attr import tensorflow as tf @@ -732,224 +731,3 @@ def _get_add_noise(stddev, seed: int = None): return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) return add_noise - - -class CentralTreeSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for differentially private tree aggregation protocol. - - Implements a central variant of the tree aggregation protocol from the paper - "'Is interaction necessary for distributed private learning?.' Adam Smith, - Abhradeep Thakurta, Jalaj Upadhyay" by replacing their local randomizer with - gaussian mechanism. The first step is to clip the clients' local updates (i.e. - a 1-D array containing the leaf nodes of the tree) by L1 norm to make sure it - does not exceed a prespecified upper bound. The second step is to construct - the tree on the clipped update. The third step is to add independent gaussian - noise to each node in the tree. The returned tree can support efficient and - accurate range queries with differential privacy. - """ - - @attr.s(frozen=True) - class GlobalState(object): - """Class defining global state for `CentralTreeSumQuery`. - - Attributes: - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. - """ - l1_bound = attr.ib() - - def __init__(self, - stddev: float, - arity: int = 2, - l1_bound: int = 10, - seed: Optional[int] = None): - """Initializes the `CentralTreeSumQuery`. - - Args: - stddev: The stddev of the noise added to each internal node of the - constructed tree. - arity: The branching factor of the tree. - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. - seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for - test purpose. - """ - self._stddev = stddev - self._arity = arity - self._l1_bound = l1_bound - self._seed = seed - - def initial_global_state(self): - """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" - return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound) - - def derive_sample_params(self, global_state): - """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" - return global_state.l1_bound - - def preprocess_record(self, params, record): - """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" - casted_record = tf.cast(record, tf.float32) - l1_norm = tf.norm(casted_record, ord=1) - - l1_bound = tf.cast(params, tf.float32) - - preprocessed_record, _ = tf.clip_by_global_norm([casted_record], - l1_bound, - use_norm=l1_norm) - - return preprocessed_record[0] - - def get_noised_result(self, sample_state, global_state): - """Implements `tensorflow_privacy.DPQuery.get_noised_result`. - - Args: - sample_state: a frequency histogram. - global_state: hyper-parameters of the query. - - Returns: - a `tf.RaggedTensor` representing the tree built on top of `sample_state`. - The jth node on the ith layer of the tree can be accessed by tree[i][j] - where tree is the returned value. - """ - add_noise = _get_add_noise(self._stddev, self._seed) - tree = _build_tree_from_leaf(sample_state, self._arity) - return tf.map_fn(add_noise, tree), global_state - - -class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery): - """Implements dp_query for differentially private tree aggregation protocol. - - The difference from `CentralTreeSumQuery` is that the tree construction and - gaussian noise addition happen in `preprocess_records`. The difference only - takes effect when used together with - `tff.aggregators.DifferentiallyPrivateFactory`. In other cases, this class - should be treated as equal with `CentralTreeSumQuery`. - - Implements a distributed version of the tree aggregation protocol from. "Is - interaction necessary for distributed private learning?." by replacing their - local randomizer with gaussian mechanism. The first step is to check the L1 - norm of the clients' local updates (i.e. a 1-D array containing the leaf nodes - of the tree) to make sure it does not exceed a prespecified upper bound. The - second step is to construct the tree. The third step is to add independent - gaussian noise to each node in the tree. The returned tree can support - efficient and accurate range queries with differential privacy. - """ - - @attr.s(frozen=True) - class GlobalState(object): - """Class defining global state for DistributedTreeSumQuery. - - Attributes: - stddev: The stddev of the noise added to each internal node in the - constructed tree. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. - """ - stddev = attr.ib() - arity = attr.ib() - l1_bound = attr.ib() - - def __init__(self, - stddev: float, - arity: int = 2, - l1_bound: int = 10, - seed: Optional[int] = None): - """Initializes the `DistributedTreeSumQuery`. - - Args: - stddev: The stddev of the noise added to each node in the tree. - arity: The branching factor of the tree. - l1_bound: An upper bound on the L1 norm of the input record. This is - needed to bound the sensitivity and deploy differential privacy. - seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for - test purpose. - """ - self._stddev = stddev - self._arity = arity - self._l1_bound = l1_bound - self._seed = seed - - def initial_global_state(self): - """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" - return DistributedTreeSumQuery.GlobalState( - stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound) - - def derive_sample_params(self, global_state): - """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" - return (global_state.stddev, global_state.arity, global_state.l1_bound) - - def preprocess_record(self, params, record): - """Implements `tensorflow_privacy.DPQuery.preprocess_record`. - - This method clips the input record by L1 norm, constructs a tree on top of - it, and adds gaussian noise to each node of the tree for differential - privacy. Unlike `get_noised_result` in `CentralTreeSumQuery`, this function - flattens the `tf.RaggedTensor` before outputting it. This is useful when - used inside `tff.aggregators.DifferentiallyPrivateFactory` because it does - not accept ragged output tensor. - - Args: - params: hyper-parameters for preprocessing record, (stddev, aritry, - l1_bound) - record: leaf nodes for the tree. - - Returns: - `tf.Tensor` representing the flattened version of the tree. - """ - _, arity, l1_bound_ = params - l1_bound = tf.cast(l1_bound_, tf.float32) - - casted_record = tf.cast(record, tf.float32) - l1_norm = tf.norm(casted_record, ord=1) - - preprocessed_record, _ = tf.clip_by_global_norm([casted_record], - l1_bound, - use_norm=l1_norm) - preprocessed_record = preprocessed_record[0] - - add_noise = _get_add_noise(self._stddev, self._seed) - tree = _build_tree_from_leaf(preprocessed_record, arity) - noisy_tree = tf.map_fn(add_noise, tree) - - # The following codes reshape the output vector so the output shape of can - # be statically inferred. This is useful when used with - # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know - # the output shape of this function statically and explicitly. - flat_noisy_tree = noisy_tree.flat_values - flat_tree_shape = [ - (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - - 1) // (self._arity - 1) - ] - return tf.reshape(flat_noisy_tree, flat_tree_shape) - - def get_noised_result(self, sample_state, global_state): - """Implements `tensorflow_privacy.DPQuery.get_noised_result`. - - This function re-constructs the `tf.RaggedTensor` from the flattened tree - output by `preprocess_records.` - - Args: - sample_state: `tf.Tensor` for the flattened tree. - global_state: hyper-parameters including noise multiplier, the branching - factor of the tree and the maximum records per user. - - Returns: - a `tf.RaggedTensor` for the tree. - """ - # The [0] is needed because of how tf.RaggedTensor.from_two_splits works. - # print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6], - # row_splits=[0, 4, 4, 7, 8, 8])) - # - # This part is not written in tensorflow and will be executed on the server - # side instead of the client side if used with - # tff.aggregators.DifferentiallyPrivateFactory for federated learning. - row_splits = [0] + [ - (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range( - math.floor(math.log(sample_state.shape[0], self._arity)) + 1) - ] - tree = tf.RaggedTensor.from_row_splits( - values=sample_state, row_splits=row_splits) - return tree, global_state diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index 1bfaa21..f88ed90 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -630,242 +630,5 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase): sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev) -class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): - - def test_initial_global_state_type(self): - - query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - self.assertIsInstance( - global_state, tree_aggregation_query.CentralTreeSumQuery.GlobalState) - - def test_derive_sample_params(self): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - self.assertAllClose(params, 10.) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32)), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32)), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32)), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], - dtype=tf.float32)), - ) - def test_preprocess_record(self, arity, record): - query = tree_aggregation_query.CentralTreeSumQuery( - stddev=NOISE_STD, arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - - self.assertAllClose(preprocessed_record, record) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.constant([5, 5, 0, 0], dtype=tf.int32)), - ('binary_test_float', 2, tf.constant( - [10., 10., 0., 0.], - dtype=tf.float32), tf.constant([5., 5., 0., 0.], dtype=tf.float32)), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.constant([5, 5, 0, 0], dtype=tf.int32)), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.constant([5., 5., 0., 0.], dtype=tf.float32)), - ) - def test_preprocess_record_clipped(self, arity, record, - expected_clipped_value): - query = tree_aggregation_query.CentralTreeSumQuery( - stddev=NOISE_STD, arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_clipped_value) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ) - def test_get_noised_result(self, arity, record, expected_tree): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - sample_state, global_state = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose(sample_state, expected_tree) - - @parameterized.named_parameters( - ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ) - def test_get_noised_result_with_noise(self, stddev, record, expected_tree): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - - sample_state, _ = query.get_noised_result(preprocessed_record, global_state) - - self.assertAllClose( - sample_state.flat_values, expected_tree, atol=3 * stddev) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), - ) - def test_get_noised_result_clipped(self, arity, record, expected_tree): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - sample_state, global_state = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose(sample_state, expected_tree) - - -class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase): - - def test_initial_global_state_type(self): - - query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - self.assertIsInstance( - global_state, - tree_aggregation_query.DistributedTreeSumQuery.GlobalState) - - def test_derive_sample_params(self): - query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) - global_state = query.initial_global_state() - stddev, arity, l1_bound = query.derive_sample_params(global_state) - self.assertAllClose(stddev, NOISE_STD) - self.assertAllClose(arity, 2) - self.assertAllClose(l1_bound, 10) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([1., 1., 0., 1., 0., 0., 0.])), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. - ])), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0. - ])), - ) - def test_preprocess_record(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_tree) - - @parameterized.named_parameters( - ('stddev_0_01', 0.01, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), - ) - def test_preprocess_record_with_noise(self, stddev, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=stddev, seed=0) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - - preprocessed_record = query.preprocess_record(params, record) - - self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), - ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([10., 10., 0., 5., 5., 0., 0.])), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant( - [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant( - [10., 10., 0., 0., 5., 5., 0., 0., 0., 0., 0., 0., 0.])), - ) - def test_preprocess_record_clipped(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(preprocessed_record, expected_tree) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('binary_test_float', 2, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0.], [1., 0., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([1, 0, 0, 0], dtype=tf.int32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([1., 0., 0., 0.], dtype=tf.float32), - tf.ragged.constant([[1.], [1., 0., 0.], - [1., 0., 0., 0., 0., 0., 0., 0., 0.]])), - ) - def test_get_noised_result(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - sample_state, global_state = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose(sample_state, expected_tree) - - @parameterized.named_parameters( - ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('binary_test_float', 2, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0.], [5., 5., 0., 0.]])), - ('ternary_test_int', 3, tf.constant([10, 10, 0, 0], dtype=tf.int32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), - ('ternary_test_float', 3, tf.constant([10., 10., 0., 0.], - dtype=tf.float32), - tf.ragged.constant([[10.], [10., 0., 0.], - [5., 5., 0., 0., 0., 0., 0., 0., 0.]])), - ) - def test_get_noised_result_clipped(self, arity, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery( - stddev=0., arity=arity) - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - preprocessed_record = query.preprocess_record(params, record) - sample_state, global_state = query.get_noised_result( - preprocessed_record, global_state) - - self.assertAllClose(sample_state, expected_tree) - - if __name__ == '__main__': tf.test.main()