From 8850c23f67d31a3baee2490224dea9a605581ddd Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Tue, 7 Dec 2021 10:48:30 -0800 Subject: [PATCH] Move tree_aggregation accountant to their own module. PiperOrigin-RevId: 414770173 --- tensorflow_privacy/__init__.py | 3 +- .../privacy/analysis/rdp_accountant.py | 251 -------------- .../privacy/analysis/rdp_accountant_test.py | 156 --------- .../analysis/tree_aggregation_accountant.py | 315 ++++++++++++++++++ .../tree_aggregation_accountant_test.py | 185 ++++++++++ 5 files changed, 502 insertions(+), 408 deletions(-) create mode 100644 tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py create mode 100644 tensorflow_privacy/privacy/analysis/tree_aggregation_accountant_test.py diff --git a/tensorflow_privacy/__init__.py b/tensorflow_privacy/__init__.py index 72cc746..fcf607e 100644 --- a/tensorflow_privacy/__init__.py +++ b/tensorflow_privacy/__init__.py @@ -45,8 +45,9 @@ else: from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_heterogeneous_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp - from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_tree_restart from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent + from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_tree_restart + from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_single_tree # DPQuery classes from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py index e00c7b3..380ff9c 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py @@ -40,11 +40,8 @@ from __future__ import absolute_import 
from __future__ import division from __future__ import print_function -import functools import math import sys -from typing import Collection, Union - import numpy as np from scipy import special import six @@ -399,254 +396,6 @@ def compute_rdp(q, noise_multiplier, steps, orders): return rdp * steps -# TODO(b/193679963): move accounting for tree aggregation to a separate module -def _compute_rdp_tree_restart(sigma, steps_list, alpha): - """Computes RDP of the Tree Aggregation Protocol at order alpha.""" - if np.isinf(alpha): - return np.inf - tree_depths = [ - math.floor(math.log2(float(steps))) + 1 - for steps in steps_list - if steps > 0 - ] - return _compute_gaussian_rdp( - alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma) - - -def compute_rdp_tree_restart( - noise_multiplier: float, steps_list: Union[int, Collection[int]], - orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]: - """Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism. - - This function implements the accounting when the tree is restarted at every - epoch. See appendix of - "Practical and Private (Deep) Learning without Sampling or Shuffling" - https://arxiv.org/abs/2103.00039. - - Args: - noise_multiplier: A non-negative float representing the ratio of the - standard deviation of the Gaussian noise to the l2-sensitivity of the - function to which it is added. - steps_list: A scalar or a list of non-negative intergers representing the - number of steps per epoch (between two restarts). - orders: An array (or a scalar) of RDP orders. - - Returns: - The RDPs at all orders. Can be `np.inf`. 
- """ - _check_nonnegative(noise_multiplier, "noise_multiplier") - if noise_multiplier == 0: - return np.inf - - if not steps_list: - raise ValueError( - "steps_list must be a non-empty list, or a non-zero scalar, got " - f"{steps_list}.") - - if np.isscalar(steps_list): - steps_list = [steps_list] - - for steps in steps_list: - if steps < 0: - raise ValueError(f"Steps must be non-negative, got {steps_list}") - - if np.isscalar(orders): - rdp = _compute_rdp_tree_restart(noise_multiplier, steps_list, orders) - else: - rdp = np.array([ - _compute_rdp_tree_restart(noise_multiplier, steps_list, alpha) - for alpha in orders - ]) - - return rdp - - -def _check_nonnegative(value: Union[int, float], name: str): - if value < 0: - raise ValueError(f"Provided {name} must be non-negative, got {value}") - - -def _check_possible_tree_participation(num_participation: int, - min_separation: int, start: int, - end: int, steps: int) -> bool: - """Check if participation is possible with `min_separation` in `steps`. - - This function checks if it is possible for a sample to appear - `num_participation` in `steps`, assuming there are at least `min_separation` - nodes between the appearance of the same sample in the streaming data (leaf - nodes in tree aggregation). The first appearance of the sample is after - `start` steps, and the sample won't appear in the `end` steps after the given - `steps`. - - Args: - num_participation: The number of times a sample will appear. - min_separation: The minimum number of nodes between two appearance of a - sample. If a sample appears in consecutive x, y steps in a streaming - setting, then `min_separation=y-x-1`. - start: The first appearance of the sample is after `start` steps. - end: The sample won't appear in the `end` steps after the given `steps`. - steps: Total number of steps (leaf nodes in tree aggregation). - - Returns: - True if a sample can appear `num_participation` with given conditions. 
- """ - return start + (min_separation + 1) * num_participation <= steps + end - - -@functools.lru_cache(maxsize=None) -def _tree_sensitivity_square_sum(num_participation: int, min_separation: int, - start: int, end: int, size: int) -> float: - """Compute the worst-case sum of sensitivtiy square for `num_participation`. - - This is the key algorithm for DP accounting for DP-FTRL tree aggregation - without restart, which recurrently counts the worst-case occurence of a sample - in all the nodes in a tree. This implements a dynamic programming algorithm - that exhausts the possible `num_participation` appearance of a sample in - `size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of - "Practical and Private (Deep) Learning without Sampling or Shuffling" - https://arxiv.org/abs/2103.00039. - - Args: - num_participation: The number of times a sample will appear. - min_separation: The minimum number of nodes between two appearance of a - sample. If a sample appears in consecutive x, y size in a streaming - setting, then `min_separation=y-x-1`. - start: The first appearance of the sample is after `start` steps. - end: The sample won't appear in the `end` steps after given `size` steps. - size: Total number of steps (leaf nodes in tree aggregation). - - Returns: - The worst-case sum of sensitivity square for the given input. - """ - if not _check_possible_tree_participation(num_participation, min_separation, - start, end, size): - sum_value = -np.inf - elif num_participation == 0: - sum_value = 0. - elif num_participation == 1 and size == 1: - sum_value = 1. - else: - size_log2 = math.log2(size) - max_2power = math.floor(size_log2) - if max_2power == size_log2: - sum_value = num_participation**2 - max_2power -= 1 - else: - sum_value = 0. 
- candidate_sum = [] - # i is the `num_participation` in the right subtree - for i in range(num_participation + 1): - # j is the `start` in the right subtree - for j in range(min_separation + 1): - left_sum = _tree_sensitivity_square_sum( - num_participation=num_participation - i, - min_separation=min_separation, - start=start, - end=j, - size=2**max_2power) - if np.isinf(left_sum): - candidate_sum.append(-np.inf) - continue # Early pruning for dynamic programming - right_sum = _tree_sensitivity_square_sum( - num_participation=i, - min_separation=min_separation, - start=j, - end=end, - size=size - 2**max_2power) - candidate_sum.append(left_sum + right_sum) - sum_value += max(candidate_sum) - return sum_value - - -def _max_tree_sensitivity_square_sum(max_participation: int, - min_separation: int, steps: int) -> float: - """Compute the worst-case sum of sensitivity square in tree aggregation. - - See Appendix D.2 of - "Practical and Private (Deep) Learning without Sampling or Shuffling" - https://arxiv.org/abs/2103.00039. - - Args: - max_participation: The maximum number of times a sample will appear. - min_separation: The minimum number of nodes between two appearance of a - sample. If a sample appears in consecutive x, y steps in a streaming - setting, then `min_separation=y-x-1`. - steps: Total number of steps (leaf nodes in tree aggregation). - - Returns: - The worst-case sum of sensitivity square for the given input. 
- """ - num_participation = max_participation - while not _check_possible_tree_participation( - num_participation, min_separation, 0, min_separation, steps): - num_participation -= 1 - candidate_sum = [] - for num_part in range(1, num_participation + 1): - candidate_sum.append( - _tree_sensitivity_square_sum(num_part, min_separation, 0, - min_separation, steps)) - return max(candidate_sum) - - -def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float, - alpha: float) -> float: - """Computes RDP of Gaussian mechanism.""" - if np.isinf(alpha): - return np.inf - return alpha * sum_sensitivity_square / (2 * sigma**2) - - -def compute_rdp_single_tree( - noise_multiplier: float, total_steps: int, max_participation: int, - min_separation: int, - orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]: - """Computes RDP of the Tree Aggregation Protocol for a single tree. - - The accounting assume a single tree is constructed for `total_steps` leaf - nodes, where the same sample will appear at most `max_participation` times, - and there are at least `min_separation` nodes between two appearance. The key - idea is to (recurrently) count the worst-case occurence of a sample - in all the nodes in a tree, which implements a dynamic programming algorithm - that exhausts the possible `num_participation` appearance of a sample in - `steps` leaf nodes. - - See Appendix D of - "Practical and Private (Deep) Learning without Sampling or Shuffling" - https://arxiv.org/abs/2103.00039. - - Args: - noise_multiplier: A non-negative float representing the ratio of the - standard deviation of the Gaussian noise to the l2-sensitivity of the - function to which it is added. - total_steps: Total number of steps (leaf nodes in tree aggregation). - max_participation: The maximum number of times a sample can appear. - min_separation: The minimum number of nodes between two appearance of a - sample. 
If a sample appears in consecutive x, y steps in a streaming - setting, then `min_separation=y-x-1`. - orders: An array (or a scalar) of RDP orders. - - Returns: - The RDPs at all orders. Can be `np.inf`. - """ - _check_nonnegative(noise_multiplier, "noise_multiplier") - if noise_multiplier == 0: - return np.inf - _check_nonnegative(total_steps, "total_steps") - _check_nonnegative(max_participation, "max_participation") - _check_nonnegative(min_separation, "min_separation") - sum_sensitivity_square = _max_tree_sensitivity_square_sum( - max_participation, min_separation, total_steps) - if np.isscalar(orders): - rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, - orders) - else: - rdp = np.array([ - _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha) - for alpha in orders - ]) - return rdp - - def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders): """Compute RDP of Gaussian Mechanism using sampling without replacement. diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py index df241af..63983ad 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py @@ -264,161 +264,5 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase): self.assertLessEqual(delta, delta1 + 1e-300) -class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04)) - def test_compute_eps_tree(self, noise_multiplier, eps): - orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) - # This tests is based on the StackOverflow setting in "Practical and - # Private (Deep) Learning without Sampling or Shuffling". The calculated - # epsilon could be better as the method in this package keeps improving. 
- steps_list, target_delta = 1600, 1e-6 - rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps_list, - orders) - new_eps = rdp_accountant.get_privacy_spent( - orders, rdp, target_delta=target_delta)[0] - self.assertLess(new_eps, eps) - - @parameterized.named_parameters( - ('restart4', [400] * 4), - ('restart2', [800] * 2), - ('adaptive', [10, 400, 400, 400, 390]), - ) - def test_compose_tree_rdp(self, steps_list): - noise_multiplier, orders = 0.1, 1 - rdp_list = [ - rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps, orders) - for steps in steps_list - ] - rdp_composed = rdp_accountant.compute_rdp_tree_restart( - noise_multiplier, steps_list, orders) - self.assertAllClose(rdp_composed, sum(rdp_list), rtol=1e-12) - - @parameterized.named_parameters( - ('restart4', [400] * 4), - ('restart2', [800] * 2), - ('adaptive', [10, 400, 400, 400, 390]), - ) - def test_compute_eps_tree_decreasing(self, steps_list): - # Test privacy epsilon decreases with noise multiplier increasing when - # keeping other parameters the same. - orders = [1 + x / 10. 
for x in range(1, 100)] + list(range(12, 64)) - target_delta = 1e-6 - prev_eps = rdp_accountant.compute_rdp_tree_restart(0, steps_list, orders) - for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]: - rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier, - steps_list, orders) - eps = rdp_accountant.get_privacy_spent( - orders, rdp, target_delta=target_delta)[0] - self.assertLess(eps, prev_eps) - - @parameterized.named_parameters( - ('negative_noise', -1, 3, 1), - ('empty_steps', 1, [], 1), - ('negative_steps', 1, -3, 1), - ) - def test_compute_rdp_tree_restart_raise(self, noise_multiplier, steps_list, - orders): - with self.assertRaisesRegex(ValueError, 'must be'): - rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps_list, - orders) - - @parameterized.named_parameters( - ('t100n0.1', 100, 0.1), - ('t1000n0.01', 1000, 0.01), - ) - def test_no_tree_no_sampling(self, total_steps, noise_multiplier): - orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) - tree_rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier, - [1] * total_steps, - orders) - rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders) - self.assertAllClose(tree_rdp, rdp, rtol=1e-12) - - @parameterized.named_parameters( - ('negative_noise', -1, 3, 1, 1), - ('negative_steps', 0.1, -3, 1, 1), - ('negative_part', 0.1, 3, -1, 1), - ('negative_sep', 0.1, 3, 1, -1), - ) - def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps, - max_participation, min_separation): - orders = 1 - with self.assertRaisesRegex(ValueError, 'must be'): - rdp_accountant.compute_rdp_single_tree(noise_multiplier, total_steps, - max_participation, min_separation, - orders) - - @parameterized.named_parameters( - ('3', 3), - ('8', 8), - ('11', 11), - ('19', 19), - ) - def test_max_tree_sensitivity_square_sum_every_step(self, steps): - max_participation, min_separation = steps, 0 - # If a sample will appear in every leaf node, we can infer the 
total - # sensitivity by adding all the nodes. - steps_bin = bin(steps)[2:] - depth = [ - len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1' - ] - expected = sum([2**d * (2**(d + 1) - 1) for d in depth]) - self.assertEqual( - expected, - rdp_accountant._max_tree_sensitivity_square_sum(max_participation, - min_separation, steps)) - - @parameterized.named_parameters( - ('11', 11), - ('19', 19), - ('200', 200), - ) - def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part): - steps, min_separation = 8, 0 - assert max_part > steps - # If a sample will appear in every leaf node, we can infer the total - # sensitivity by adding all the nodes. - expected = 120 - self.assertEqual( - expected, - rdp_accountant._max_tree_sensitivity_square_sum(max_part, - min_separation, steps)) - - @parameterized.named_parameters( - ('3', 3), - ('8', 8), - ('11', 11), - ('19', 19), - ) - def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps): - max_participation, min_separation = 2, 0 - # If a sample will appear twice, the worst case is to put the two nodes at - # consecutive nodes of the deepest subtree. - steps_bin = bin(steps)[2:] - depth = len(steps_bin) - 1 - expected = 2 + 4 * depth - self.assertEqual( - expected, - rdp_accountant._max_tree_sensitivity_square_sum(max_participation, - min_separation, steps)) - - @parameterized.named_parameters( - ('test1', 1, 7, 8, 4), - ('test2', 3, 3, 9, 11), - ('test3', 3, 2, 7, 9), - # This is an example showing worst-case sensitivity is larger than greedy - # in "Practical and Private (Deep) Learning without Sampling or Shuffling" - # https://arxiv.org/abs/2103.00039. 
- ('test4', 8, 2, 24, 88), - ) - def test_max_tree_sensitivity_square_sum_toy(self, max_participation, - min_separation, steps, expected): - self.assertEqual( - expected, - rdp_accountant._max_tree_sensitivity_square_sum(max_participation, - min_separation, steps)) - - if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py b/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py new file mode 100644 index 0000000..417c910 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py @@ -0,0 +1,315 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""DP analysis of tree aggregation. + +See Appendix D of +"Practical and Private (Deep) Learning without Sampling or Shuffling" + https://arxiv.org/abs/2103.00039. + +Functionality for computing differential privacy of tree aggregation of Gaussian +mechanism. Its public interface consists of the following methods: + compute_rdp_tree_restart( + noise_multiplier: float, steps_list: Union[int, Collection[int]], + orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]: + computes RDP for DP-FTRL-TreeRestart. 
+ compute_rdp_single_tree( + noise_multiplier: float, total_steps: int, max_participation: int, + min_separation: int, + orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]: + computes RDP for DP-FTRL-NoTreeRestart. + +For RDP to (epsilon, delta)-DP conversion, use the following public function +described in `rdp_accountant.py`: + get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta + (or eps) given RDP at multiple orders and + a target value for eps (or delta). + +Example use: + +(1) DP-FTRL-TreeRestart RDP: +Suppose we use Gaussian mechanism of `noise_multiplier`; a sample may appear +at most once for every epoch and tree is restarted every epoch; the number of +leaf nodes for every epoch are tracked in `steps_list`. For `target_delta`, the +estimated epsilon is: + orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) + rdp = compute_rdp_tree_restart(noise_multiplier, steps_list, orders) + eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0] + +(2) DP-FTRL-NoTreeRestart RDP: +Suppose we use Gaussian mechanism of `noise_multiplier`; a sample may appear +at most `max_participation` times for a total of `total_steps` leaf nodes in a +single tree; there are at least `min_separation` leaf nodes between the two +appearance of a same sample. For `target_delta`, the estimated epsilon is: + orders = [1 + x / 10. 
for x in range(1, 100)] + list(range(12, 64)) + rdp = compute_rdp_single_tree(noise_multiplier, total_steps, + max_participation, min_separation, orders) + eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0] +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import math +from typing import Collection, Union + +import numpy as np + + +def _compute_rdp_tree_restart(sigma, steps_list, alpha): + """Computes RDP of the Tree Aggregation Protocol at order alpha.""" + if np.isinf(alpha): + return np.inf + tree_depths = [ + math.floor(math.log2(float(steps))) + 1 + for steps in steps_list + if steps > 0 + ] + return _compute_gaussian_rdp( + alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma) + + +def compute_rdp_tree_restart( + noise_multiplier: float, steps_list: Union[int, Collection[int]], + orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]: + """Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism. + + This function implements the accounting when the tree is restarted at every + epoch. See appendix D of + "Practical and Private (Deep) Learning without Sampling or Shuffling" + https://arxiv.org/abs/2103.00039. + + Args: + noise_multiplier: A non-negative float representing the ratio of the + standard deviation of the Gaussian noise to the l2-sensitivity of the + function to which it is added. + steps_list: A scalar or a list of non-negative intergers representing the + number of steps per epoch (between two restarts). + orders: An array (or a scalar) of RDP orders. + + Returns: + The RDPs at all orders. Can be `np.inf`. 
+ """ + _check_nonnegative(noise_multiplier, "noise_multiplier") + if noise_multiplier == 0: + return np.inf + + if not steps_list: + raise ValueError( + "steps_list must be a non-empty list, or a non-zero scalar, got " + f"{steps_list}.") + + if np.isscalar(steps_list): + steps_list = [steps_list] + + for steps in steps_list: + if steps < 0: + raise ValueError(f"Steps must be non-negative, got {steps_list}") + + if np.isscalar(orders): + rdp = _compute_rdp_tree_restart(noise_multiplier, steps_list, orders) + else: + rdp = np.array([ + _compute_rdp_tree_restart(noise_multiplier, steps_list, alpha) + for alpha in orders + ]) + + return rdp + + +def _check_nonnegative(value: Union[int, float], name: str): + if value < 0: + raise ValueError(f"Provided {name} must be non-negative, got {value}") + + +def _check_possible_tree_participation(num_participation: int, + min_separation: int, start: int, + end: int, steps: int) -> bool: + """Check if participation is possible with `min_separation` in `steps`. + + This function checks if it is possible for a sample to appear + `num_participation` in `steps`, assuming there are at least `min_separation` + nodes between the appearance of the same sample in the streaming data (leaf + nodes in tree aggregation). The first appearance of the sample is after + `start` steps, and the sample won't appear in the `end` steps after the given + `steps`. + + Args: + num_participation: The number of times a sample will appear. + min_separation: The minimum number of nodes between two appearance of a + sample. If a sample appears in consecutive x, y steps in a streaming + setting, then `min_separation=y-x-1`. + start: The first appearance of the sample is after `start` steps. + end: The sample won't appear in the `end` steps after the given `steps`. + steps: Total number of steps (leaf nodes in tree aggregation). + + Returns: + True if a sample can appear `num_participation` with given conditions. 
+ """ + return start + (min_separation + 1) * num_participation <= steps + end + + +@functools.lru_cache(maxsize=None) +def _tree_sensitivity_square_sum(num_participation: int, min_separation: int, + start: int, end: int, size: int) -> float: + """Compute the worst-case sum of sensitivtiy square for `num_participation`. + + This is the key algorithm for DP accounting for DP-FTRL tree aggregation + without restart, which recurrently counts the worst-case occurence of a sample + in all the nodes in a tree. This implements a dynamic programming algorithm + that exhausts the possible `num_participation` appearance of a sample in + `size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of + "Practical and Private (Deep) Learning without Sampling or Shuffling" + https://arxiv.org/abs/2103.00039. + + Args: + num_participation: The number of times a sample will appear. + min_separation: The minimum number of nodes between two appearance of a + sample. If a sample appears in consecutive x, y size in a streaming + setting, then `min_separation=y-x-1`. + start: The first appearance of the sample is after `start` steps. + end: The sample won't appear in the `end` steps after given `size` steps. + size: Total number of steps (leaf nodes in tree aggregation). + + Returns: + The worst-case sum of sensitivity square for the given input. + """ + if not _check_possible_tree_participation(num_participation, min_separation, + start, end, size): + sum_value = -np.inf + elif num_participation == 0: + sum_value = 0. + elif num_participation == 1 and size == 1: + sum_value = 1. + else: + size_log2 = math.log2(size) + max_2power = math.floor(size_log2) + if max_2power == size_log2: + sum_value = num_participation**2 + max_2power -= 1 + else: + sum_value = 0. 
+ candidate_sum = [] + # i is the `num_participation` in the right subtree + for i in range(num_participation + 1): + # j is the `start` in the right subtree + for j in range(min_separation + 1): + left_sum = _tree_sensitivity_square_sum( + num_participation=num_participation - i, + min_separation=min_separation, + start=start, + end=j, + size=2**max_2power) + if np.isinf(left_sum): + candidate_sum.append(-np.inf) + continue # Early pruning for dynamic programming + right_sum = _tree_sensitivity_square_sum( + num_participation=i, + min_separation=min_separation, + start=j, + end=end, + size=size - 2**max_2power) + candidate_sum.append(left_sum + right_sum) + sum_value += max(candidate_sum) + return sum_value + + +def _max_tree_sensitivity_square_sum(max_participation: int, + min_separation: int, steps: int) -> float: + """Compute the worst-case sum of sensitivity square in tree aggregation. + + See Appendix D.2 of + "Practical and Private (Deep) Learning without Sampling or Shuffling" + https://arxiv.org/abs/2103.00039. + + Args: + max_participation: The maximum number of times a sample will appear. + min_separation: The minimum number of nodes between two appearance of a + sample. If a sample appears in consecutive x, y steps in a streaming + setting, then `min_separation=y-x-1`. + steps: Total number of steps (leaf nodes in tree aggregation). + + Returns: + The worst-case sum of sensitivity square for the given input. 
+ """ + num_participation = max_participation + while not _check_possible_tree_participation( + num_participation, min_separation, 0, min_separation, steps): + num_participation -= 1 + candidate_sum = [] + for num_part in range(1, num_participation + 1): + candidate_sum.append( + _tree_sensitivity_square_sum(num_part, min_separation, 0, + min_separation, steps)) + return max(candidate_sum) + + +def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float, + alpha: float) -> float: + """Computes RDP of Gaussian mechanism.""" + if np.isinf(alpha): + return np.inf + return alpha * sum_sensitivity_square / (2 * sigma**2) + + +def compute_rdp_single_tree( + noise_multiplier: float, total_steps: int, max_participation: int, + min_separation: int, + orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]: + """Computes RDP of the Tree Aggregation Protocol for a single tree. + + The accounting assume a single tree is constructed for `total_steps` leaf + nodes, where the same sample will appear at most `max_participation` times, + and there are at least `min_separation` nodes between two appearance. The key + idea is to (recurrently) count the worst-case occurence of a sample + in all the nodes in a tree, which implements a dynamic programming algorithm + that exhausts the possible `num_participation` appearance of a sample in + `steps` leaf nodes. + + See Appendix D of + "Practical and Private (Deep) Learning without Sampling or Shuffling" + https://arxiv.org/abs/2103.00039. + + Args: + noise_multiplier: A non-negative float representing the ratio of the + standard deviation of the Gaussian noise to the l2-sensitivity of the + function to which it is added. + total_steps: Total number of steps (leaf nodes in tree aggregation). + max_participation: The maximum number of times a sample can appear. + min_separation: The minimum number of nodes between two appearance of a + sample. 
If a sample appears in consecutive x, y steps in a streaming + setting, then `min_separation=y-x-1`. + orders: An array (or a scalar) of RDP orders. + + Returns: + The RDPs at all orders. Can be `np.inf`. + """ + _check_nonnegative(noise_multiplier, "noise_multiplier") + if noise_multiplier == 0: + return np.inf + _check_nonnegative(total_steps, "total_steps") + _check_nonnegative(max_participation, "max_participation") + _check_nonnegative(min_separation, "min_separation") + sum_sensitivity_square = _max_tree_sensitivity_square_sum( + max_participation, min_separation, total_steps) + if np.isscalar(orders): + rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, + orders) + else: + rdp = np.array([ + _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha) + for alpha in orders + ]) + return rdp diff --git a/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant_test.py b/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant_test.py new file mode 100644 index 0000000..17f6437 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant_test.py @@ -0,0 +1,185 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Tests for tree_aggregation_accountant.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import tensorflow as tf

from tensorflow_privacy.privacy.analysis import rdp_accountant
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant


class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the RDP accounting helpers of tree aggregation."""

  @parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04))
  def test_compute_eps_tree(self, noise_multiplier, eps):
    """Checks the computed epsilon is below the given reference value."""
    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
    # This test is based on the StackOverflow setting in "Practical and
    # Private (Deep) Learning without Sampling or Shuffling". The calculated
    # epsilon could be better as the method in this package keeps improving.
    steps_list, target_delta = 1600, 1e-6
    rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
        noise_multiplier, steps_list, orders)
    new_eps = rdp_accountant.get_privacy_spent(
        orders, rdp, target_delta=target_delta)[0]
    self.assertLess(new_eps, eps)

  @parameterized.named_parameters(
      ('restart4', [400] * 4),
      ('restart2', [800] * 2),
      ('adaptive', [10, 400, 400, 400, 390]),
  )
  def test_compose_tree_rdp(self, steps_list):
    """Checks RDP of a list of steps equals the sum of per-entry RDPs."""
    noise_multiplier, orders = 0.1, 1
    rdp_list = [
        tree_aggregation_accountant.compute_rdp_tree_restart(
            noise_multiplier, steps, orders) for steps in steps_list
    ]
    rdp_composed = tree_aggregation_accountant.compute_rdp_tree_restart(
        noise_multiplier, steps_list, orders)
    self.assertAllClose(rdp_composed, sum(rdp_list), rtol=1e-12)

  @parameterized.named_parameters(
      ('restart4', [400] * 4),
      ('restart2', [800] * 2),
      ('adaptive', [10, 400, 400, 400, 390]),
  )
  def test_compute_eps_tree_decreasing(self, steps_list):
    """Checks epsilon strictly decreases as the noise multiplier grows."""
    # Test privacy epsilon decreases with noise multiplier increasing when
    # keeping other parameters the same.
    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
    target_delta = 1e-6
    # NOTE(review): presumably zero noise yields an infinite value here, so it
    # acts as the initial upper bound -- confirm against
    # compute_rdp_tree_restart.
    prev_eps = tree_aggregation_accountant.compute_rdp_tree_restart(
        0, steps_list, orders)
    for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
      rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
          noise_multiplier, steps_list, orders)
      eps = rdp_accountant.get_privacy_spent(
          orders, rdp, target_delta=target_delta)[0]
      self.assertLess(eps, prev_eps)
      # Carry the bound forward; without this every iteration would only be
      # compared against the initial value and monotonicity is never checked.
      prev_eps = eps

  @parameterized.named_parameters(
      ('negative_noise', -1, 3, 1),
      ('empty_steps', 1, [], 1),
      ('negative_steps', 1, -3, 1),
  )
  def test_compute_rdp_tree_restart_raise(self, noise_multiplier, steps_list,
                                          orders):
    """Checks invalid arguments raise a ValueError matching 'must be'."""
    with self.assertRaisesRegex(ValueError, 'must be'):
      tree_aggregation_accountant.compute_rdp_tree_restart(
          noise_multiplier, steps_list, orders)

  @parameterized.named_parameters(
      ('t100n0.1', 100, 0.1),
      ('t1000n0.01', 1000, 0.01),
  )
  def test_no_tree_no_sampling(self, total_steps, noise_multiplier):
    """Checks one-step trees match `compute_rdp` with sampling rate 1."""
    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
    tree_rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
        noise_multiplier, [1] * total_steps, orders)
    rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
    self.assertAllClose(tree_rdp, rdp, rtol=1e-12)

  @parameterized.named_parameters(
      ('negative_noise', -1, 3, 1, 1),
      ('negative_steps', 0.1, -3, 1, 1),
      ('negative_part', 0.1, 3, -1, 1),
      ('negative_sep', 0.1, 3, 1, -1),
  )
  def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
                                         max_participation, min_separation):
    """Checks negative arguments raise a ValueError matching 'must be'."""
    orders = 1
    with self.assertRaisesRegex(ValueError, 'must be'):
      tree_aggregation_accountant.compute_rdp_single_tree(
          noise_multiplier, total_steps, max_participation, min_separation,
          orders)

  @parameterized.named_parameters(
      ('3', 3),
      ('8', 8),
      ('11', 11),
      ('19', 19),
  )
  def test_max_tree_sensitivity_square_sum_every_step(self, steps):
    """Checks the sensitivity sum when a sample participates in every step."""
    max_participation, min_separation = steps, 0
    # If a sample will appear in every leaf node, we can infer the total
    # sensitivity by adding all the nodes.
    steps_bin = bin(steps)[2:]
    # One entry per set bit of `steps`: the exponent of that bit.
    depth = [
        len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
    ]
    expected = sum(2**d * (2**(d + 1) - 1) for d in depth)
    self.assertEqual(
        expected,
        tree_aggregation_accountant._max_tree_sensitivity_square_sum(
            max_participation, min_separation, steps))

  @parameterized.named_parameters(
      ('11', 11),
      ('19', 19),
      ('200', 200),
  )
  def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
    """Checks the sum when max participation exceeds the number of steps."""
    steps, min_separation = 8, 0
    # Use a unittest assertion rather than a bare `assert`, which is stripped
    # when Python runs with -O.
    self.assertGreater(max_part, steps)
    # If a sample will appear in every leaf node, we can infer the total
    # sensitivity by adding all the nodes.
    expected = 120
    self.assertEqual(
        expected,
        tree_aggregation_accountant._max_tree_sensitivity_square_sum(
            max_part, min_separation, steps))

  @parameterized.named_parameters(
      ('3', 3),
      ('8', 8),
      ('11', 11),
      ('19', 19),
  )
  def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
    """Checks the sensitivity sum when a sample participates at most twice."""
    max_participation, min_separation = 2, 0
    # If a sample will appear twice, the worst case is to put the two nodes at
    # consecutive nodes of the deepest subtree.
    steps_bin = bin(steps)[2:]
    depth = len(steps_bin) - 1
    expected = 2 + 4 * depth
    self.assertEqual(
        expected,
        tree_aggregation_accountant._max_tree_sensitivity_square_sum(
            max_participation, min_separation, steps))

  @parameterized.named_parameters(
      ('test1', 1, 7, 8, 4),
      ('test2', 3, 3, 9, 11),
      ('test3', 3, 2, 7, 9),
      # This is an example showing worst-case sensitivity is larger than greedy
      # in "Practical and Private (Deep) Learning without Sampling or Shuffling"
      # https://arxiv.org/abs/2103.00039.
      ('test4', 8, 2, 24, 88),
  )
  def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
                                               min_separation, steps, expected):
    """Checks the sensitivity sum against hand-picked expected values."""
    self.assertEqual(
        expected,
        tree_aggregation_accountant._max_tree_sensitivity_square_sum(
            max_participation, min_separation, steps))


if __name__ == '__main__':
  tf.test.main()