Move tree_aggregation accountant to its own module.

PiperOrigin-RevId: 414770173
Zheng Xu 2021-12-07 10:48:30 -08:00 committed by A. Unique TensorFlower
parent 245fd069ca
commit 8850c23f67
5 changed files with 502 additions and 408 deletions

tensorflow_privacy/__init__.py

@@ -45,8 +45,9 @@ else:
 from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
 from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_heterogeneous_rdp
 from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
-from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_tree_restart
 from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
+from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_tree_restart
+from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_single_tree
 # DPQuery classes
 from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
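
For downstream users the top-level surface is unchanged: `tensorflow_privacy` still re-exports the tree aggregation accountants, and only their defining module moves. A minimal sketch of the two equivalent import styles after this commit:

from tensorflow_privacy import compute_rdp_tree_restart, compute_rdp_single_tree
# ...or, equivalently, from the new defining module:
from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_tree_restart
from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_single_tree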

tensorflow_privacy/privacy/analysis/rdp_accountant.py

@@ -40,11 +40,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import functools
 import math
 import sys
-from typing import Collection, Union
 import numpy as np
 from scipy import special
 import six
@@ -399,254 +396,6 @@ def compute_rdp(q, noise_multiplier, steps, orders):
return rdp * steps
# TODO(b/193679963): move accounting for tree aggregation to a separate module
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
if np.isinf(alpha):
return np.inf
tree_depths = [
math.floor(math.log2(float(steps))) + 1
for steps in steps_list
if steps > 0
]
return _compute_gaussian_rdp(
alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)
def compute_rdp_tree_restart(
noise_multiplier: float, steps_list: Union[int, Collection[int]],
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
"""Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism.
This function implements the accounting when the tree is restarted at every
epoch. See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
steps_list: A scalar or a list of non-negative integers representing the
number of steps per epoch (between two restarts).
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
if not steps_list:
raise ValueError(
"steps_list must be a non-empty list, or a non-zero scalar, got "
f"{steps_list}.")
if np.isscalar(steps_list):
steps_list = [steps_list]
for steps in steps_list:
if steps < 0:
raise ValueError(f"Steps must be non-negative, got {steps_list}")
if np.isscalar(orders):
rdp = _compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
else:
rdp = np.array([
_compute_rdp_tree_restart(noise_multiplier, steps_list, alpha)
for alpha in orders
])
return rdp
def _check_nonnegative(value: Union[int, float], name: str):
if value < 0:
raise ValueError(f"Provided {name} must be non-negative, got {value}")
def _check_possible_tree_participation(num_participation: int,
min_separation: int, start: int,
end: int, steps: int) -> bool:
"""Check if participation is possible with `min_separation` in `steps`.
This function checks if it is possible for a sample to appear
`num_participation` times in `steps`, assuming there are at least `min_separation`
nodes between the appearance of the same sample in the streaming data (leaf
nodes in tree aggregation). The first appearance of the sample is after
`start` steps, and the sample won't appear in the `end` steps after the given
`steps`.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
end: The sample won't appear in the `end` steps after the given `steps`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
True if a sample can appear `num_participation` times under the given conditions.
"""
return start + (min_separation + 1) * num_participation <= steps + end
@functools.lru_cache(maxsize=None)
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
start: int, end: int, size: int) -> float:
"""Compute the worst-case sum of sensitivtiy square for `num_participation`.
This is the key algorithm for DP accounting for DP-FTRL tree aggregation
without restart, which recursively counts the worst-case occurrence of a sample
in all the nodes in a tree. This implements a dynamic programming algorithm
that exhausts the possible `num_participation` appearances of a sample in
`size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
end: The sample won't appear in the `end` steps after given `size` steps.
size: Total number of steps (leaf nodes in tree aggregation).
Returns:
The worst-case sum of sensitivity square for the given input.
"""
if not _check_possible_tree_participation(num_participation, min_separation,
start, end, size):
sum_value = -np.inf
elif num_participation == 0:
sum_value = 0.
elif num_participation == 1 and size == 1:
sum_value = 1.
else:
size_log2 = math.log2(size)
max_2power = math.floor(size_log2)
if max_2power == size_log2:
sum_value = num_participation**2
max_2power -= 1
else:
sum_value = 0.
candidate_sum = []
# i is the `num_participation` in the right subtree
for i in range(num_participation + 1):
# j is the `start` in the right subtree
for j in range(min_separation + 1):
left_sum = _tree_sensitivity_square_sum(
num_participation=num_participation - i,
min_separation=min_separation,
start=start,
end=j,
size=2**max_2power)
if np.isinf(left_sum):
candidate_sum.append(-np.inf)
continue # Early pruning for dynamic programming
right_sum = _tree_sensitivity_square_sum(
num_participation=i,
min_separation=min_separation,
start=j,
end=end,
size=size - 2**max_2power)
candidate_sum.append(left_sum + right_sum)
sum_value += max(candidate_sum)
return sum_value
def _max_tree_sensitivity_square_sum(max_participation: int,
min_separation: int, steps: int) -> float:
"""Compute the worst-case sum of sensitivity square in tree aggregation.
See Appendix D.2 of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
max_participation: The maximum number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
The worst-case sum of sensitivity square for the given input.
"""
num_participation = max_participation
while not _check_possible_tree_participation(
num_participation, min_separation, 0, min_separation, steps):
num_participation -= 1
candidate_sum = []
for num_part in range(1, num_participation + 1):
candidate_sum.append(
_tree_sensitivity_square_sum(num_part, min_separation, 0,
min_separation, steps))
return max(candidate_sum)
def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
alpha: float) -> float:
"""Computes RDP of Gaussian mechanism."""
if np.isinf(alpha):
return np.inf
return alpha * sum_sensitivity_square / (2 * sigma**2)
def compute_rdp_single_tree(
noise_multiplier: float, total_steps: int, max_participation: int,
min_separation: int,
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
"""Computes RDP of the Tree Aggregation Protocol for a single tree.
The accounting assumes a single tree is constructed for `total_steps` leaf
nodes, where the same sample appears at most `max_participation` times, and
there are at least `min_separation` nodes between two appearances. The key
idea is to recursively count the worst-case occurrence of a sample in all the
nodes in a tree, via a dynamic programming algorithm that exhausts the
possible `num_participation` appearances of a sample in `steps` leaf nodes.
See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
total_steps: Total number of steps (leaf nodes in tree aggregation).
max_participation: The maximum number of times a sample can appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
_check_nonnegative(total_steps, "total_steps")
_check_nonnegative(max_participation, "max_participation")
_check_nonnegative(min_separation, "min_separation")
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
max_participation, min_separation, total_steps)
if np.isscalar(orders):
rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
orders)
else:
rdp = np.array([
_compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
for alpha in orders
])
return rdp
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
"""Compute RDP of Gaussian Mechanism using sampling without replacement.

tensorflow_privacy/privacy/analysis/rdp_accountant_test.py

@@ -264,161 +264,5 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
self.assertLessEqual(delta, delta1 + 1e-300)
class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04))
def test_compute_eps_tree(self, noise_multiplier, eps):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
# This test is based on the StackOverflow setting in "Practical and
# Private (Deep) Learning without Sampling or Shuffling". The calculated
# epsilon may tighten as the accounting methods in this package improve.
steps_list, target_delta = 1600, 1e-6
rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps_list,
orders)
new_eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
self.assertLess(new_eps, eps)
@parameterized.named_parameters(
('restart4', [400] * 4),
('restart2', [800] * 2),
('adaptive', [10, 400, 400, 400, 390]),
)
def test_compose_tree_rdp(self, steps_list):
noise_multiplier, orders = 0.1, 1
rdp_list = [
rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps, orders)
for steps in steps_list
]
rdp_composed = rdp_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
self.assertAllClose(rdp_composed, sum(rdp_list), rtol=1e-12)
@parameterized.named_parameters(
('restart4', [400] * 4),
('restart2', [800] * 2),
('adaptive', [10, 400, 400, 400, 390]),
)
def test_compute_eps_tree_decreasing(self, steps_list):
# Test that privacy epsilon decreases as the noise multiplier increases,
# keeping other parameters the same.
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
target_delta = 1e-6
prev_eps = rdp_accountant.compute_rdp_tree_restart(0, steps_list, orders)
for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier,
steps_list, orders)
eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
self.assertLess(eps, prev_eps)
@parameterized.named_parameters(
('negative_noise', -1, 3, 1),
('empty_steps', 1, [], 1),
('negative_steps', 1, -3, 1),
)
def test_compute_rdp_tree_restart_raise(self, noise_multiplier, steps_list,
orders):
with self.assertRaisesRegex(ValueError, 'must be'):
rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps_list,
orders)
@parameterized.named_parameters(
('t100n0.1', 100, 0.1),
('t1000n0.01', 1000, 0.01),
)
def test_no_tree_no_sampling(self, total_steps, noise_multiplier):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
tree_rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier,
[1] * total_steps,
orders)
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
@parameterized.named_parameters(
('negative_noise', -1, 3, 1, 1),
('negative_steps', 0.1, -3, 1, 1),
('negative_part', 0.1, 3, -1, 1),
('negative_sep', 0.1, 3, 1, -1),
)
def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
max_participation, min_separation):
orders = 1
with self.assertRaisesRegex(ValueError, 'must be'):
rdp_accountant.compute_rdp_single_tree(noise_multiplier, total_steps,
max_participation, min_separation,
orders)
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step(self, steps):
max_participation, min_separation = steps, 0
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
steps_bin = bin(steps)[2:]
depth = [
len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
]
expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
min_separation, steps))
@parameterized.named_parameters(
('11', 11),
('19', 19),
('200', 200),
)
def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
steps, min_separation = 8, 0
assert max_part > steps
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
expected = 120
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_part,
min_separation, steps))
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
max_participation, min_separation = 2, 0
# If a sample will appear twice, the worst case is to put the two nodes at
# consecutive nodes of the deepest subtree.
steps_bin = bin(steps)[2:]
depth = len(steps_bin) - 1
expected = 2 + 4 * depth
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
min_separation, steps))
@parameterized.named_parameters(
('test1', 1, 7, 8, 4),
('test2', 3, 3, 9, 11),
('test3', 3, 2, 7, 9),
# This is an example showing the worst-case sensitivity is larger than the
# greedy estimate in "Practical and Private (Deep) Learning without Sampling
# or Shuffling" https://arxiv.org/abs/2103.00039.
('test4', 8, 2, 24, 88),
)
def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
min_separation, steps, expected):
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
min_separation, steps))
if __name__ == '__main__':
tf.test.main()
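
These tests reappear below, retargeted at the new module; only the RDP-to-(epsilon, delta) conversion, `get_privacy_spent`, stays in `rdp_accountant`. A hedged sketch of the combined call, using the same order grid the tests use:

from tensorflow_privacy.privacy.analysis import rdp_accountant, tree_aggregation_accountant

orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
    noise_multiplier=1.13, steps_list=1600, orders=orders)
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)[0]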

tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py (new file)

@@ -0,0 +1,315 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DP analysis of tree aggregation.
See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Functionality for computing the differential privacy guarantees of tree
aggregation with the Gaussian mechanism. Its public interface consists of the
following methods:
compute_rdp_tree_restart(
noise_multiplier: float, steps_list: Union[int, Collection[int]],
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
computes RDP for DP-FTRL-TreeRestart.
compute_rdp_single_tree(
noise_multiplier: float, total_steps: int, max_participation: int,
min_separation: int,
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
computes RDP for DP-FTRL-NoTreeRestart.
For RDP to (epsilon, delta)-DP conversion, use the following public function
described in `rdp_accountant.py`:
get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta (or
eps) given RDP at multiple orders and a target value for eps (or delta).
Example use:
(1) DP-FTRL-TreeRestart RDP:
Suppose we use a Gaussian mechanism with `noise_multiplier`; a sample may appear
at most once per epoch and the tree is restarted every epoch; the number of
leaf nodes for every epoch is tracked in `steps_list`. For `target_delta`, the
estimated epsilon is:
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
(2) DP-FTRL-NoTreeRestart RDP:
Suppose we use a Gaussian mechanism with `noise_multiplier`; a sample may appear
at most `max_participation` times in a total of `total_steps` leaf nodes in a
single tree; there are at least `min_separation` leaf nodes between two
appearances of the same sample. For `target_delta`, the estimated epsilon is:
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = compute_rdp_single_tree(noise_multiplier, total_steps,
max_participation, min_separation, orders)
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import math
from typing import Collection, Union
import numpy as np
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
if np.isinf(alpha):
return np.inf
tree_depths = [
math.floor(math.log2(float(steps))) + 1
for steps in steps_list
if steps > 0
]
return _compute_gaussian_rdp(
alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)
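# Sanity check mirrored by test_no_tree_no_sampling in the test file below: a
# degenerate schedule of single-leaf epochs, steps_list = [1] * T, has depth 1
# per epoch, so the bound coincides with T compositions of the plain Gaussian
# mechanism.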
def compute_rdp_tree_restart(
noise_multiplier: float, steps_list: Union[int, Collection[int]],
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
"""Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism.
This function implements the accounting when the tree is restarted at every
epoch. See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
steps_list: A scalar or a list of non-negative integers representing the
number of steps per epoch (between two restarts).
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
if not steps_list:
raise ValueError(
"steps_list must be a non-empty list, or a non-zero scalar, got "
f"{steps_list}.")
if np.isscalar(steps_list):
steps_list = [steps_list]
for steps in steps_list:
if steps < 0:
raise ValueError(f"Steps must be non-negative, got {steps_list}")
if np.isscalar(orders):
rdp = _compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
else:
rdp = np.array([
_compute_rdp_tree_restart(noise_multiplier, steps_list, alpha)
for alpha in orders
])
return rdp
def _check_nonnegative(value: Union[int, float], name: str):
if value < 0:
raise ValueError(f"Provided {name} must be non-negative, got {value}")
def _check_possible_tree_participation(num_participation: int,
min_separation: int, start: int,
end: int, steps: int) -> bool:
"""Check if participation is possible with `min_separation` in `steps`.
This function checks if it is possible for a sample to appear
`num_participation` times in `steps`, assuming there are at least `min_separation`
nodes between the appearance of the same sample in the streaming data (leaf
nodes in tree aggregation). The first appearance of the sample is after
`start` steps, and the sample won't appear in the `end` steps after the given
`steps`.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
end: The sample won't appear in the `end` steps after the given `steps`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
True if a sample can appear `num_participation` times under the given conditions.
"""
return start + (min_separation + 1) * num_participation <= steps + end
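# Worked check with illustrative numbers: for start=0, min_separation=1,
# steps=7, end=1, a sample can appear 4 times, since 0 + (1 + 1) * 4 = 8 and
# 8 <= 7 + 1, but not 5 times, since 10 > 8.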
@functools.lru_cache(maxsize=None)
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
start: int, end: int, size: int) -> float:
"""Compute the worst-case sum of sensitivtiy square for `num_participation`.
This is the key algorithm for DP accounting for DP-FTRL tree aggregation
without restart, which recursively counts the worst-case occurrence of a sample
in all the nodes in a tree. This implements a dynamic programming algorithm
that exhausts the possible `num_participation` appearances of a sample in
`size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
end: The sample won't appear in the `end` steps after given `size` steps.
size: Total number of steps (leaf nodes in tree aggregation).
Returns:
The worst-case sum of sensitivity square for the given input.
"""
if not _check_possible_tree_participation(num_participation, min_separation,
start, end, size):
sum_value = -np.inf
elif num_participation == 0:
sum_value = 0.
elif num_participation == 1 and size == 1:
sum_value = 1.
else:
size_log2 = math.log2(size)
max_2power = math.floor(size_log2)
if max_2power == size_log2:
sum_value = num_participation**2
max_2power -= 1
else:
sum_value = 0.
candidate_sum = []
# i is the `num_participation` in the right subtree
for i in range(num_participation + 1):
# j is the `start` in the right subtree
for j in range(min_separation + 1):
left_sum = _tree_sensitivity_square_sum(
num_participation=num_participation - i,
min_separation=min_separation,
start=start,
end=j,
size=2**max_2power)
if np.isinf(left_sum):
candidate_sum.append(-np.inf)
continue # Early pruning for dynamic programming
right_sum = _tree_sensitivity_square_sum(
num_participation=i,
min_separation=min_separation,
start=j,
end=end,
size=size - 2**max_2power)
candidate_sum.append(left_sum + right_sum)
sum_value += max(candidate_sum)
return sum_value
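# Worked example matching the tests: with 8 leaves, min_separation=0, and a
# sample in every leaf, every node of the complete binary tree is touched. A
# node covering 2**h leaves contributes (2**h)**2, and there are 8 / 2**h such
# nodes at height h, so the total is 8*1 + 4*4 + 2*16 + 1*64 = 120.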
def _max_tree_sensitivity_square_sum(max_participation: int,
min_separation: int, steps: int) -> float:
"""Compute the worst-case sum of sensitivity square in tree aggregation.
See Appendix D.2 of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
max_participation: The maximum number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
The worst-case sum of sensitivity square for the given input.
"""
num_participation = max_participation
while not _check_possible_tree_participation(
num_participation, min_separation, 0, min_separation, steps):
num_participation -= 1
candidate_sum = []
for num_part in range(1, num_participation + 1):
candidate_sum.append(
_tree_sensitivity_square_sum(num_part, min_separation, 0,
min_separation, steps))
return max(candidate_sum)
def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
alpha: float) -> float:
"""Computes RDP of Gaussian mechanism."""
if np.isinf(alpha):
return np.inf
return alpha * sum_sensitivity_square / (2 * sigma**2)
def compute_rdp_single_tree(
noise_multiplier: float, total_steps: int, max_participation: int,
min_separation: int,
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
"""Computes RDP of the Tree Aggregation Protocol for a single tree.
The accounting assumes a single tree is constructed for `total_steps` leaf
nodes, where the same sample appears at most `max_participation` times, and
there are at least `min_separation` nodes between two appearances. The key
idea is to recursively count the worst-case occurrence of a sample in all the
nodes in a tree, via a dynamic programming algorithm that exhausts the
possible `num_participation` appearances of a sample in `steps` leaf nodes.
See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
total_steps: Total number of steps (leaf nodes in tree aggregation).
max_participation: The maximum number of times a sample can appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
_check_nonnegative(total_steps, "total_steps")
_check_nonnegative(max_participation, "max_participation")
_check_nonnegative(min_separation, "min_separation")
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
max_participation, min_separation, total_steps)
if np.isscalar(orders):
rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
orders)
else:
rdp = np.array([
_compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
for alpha in orders
])
return rdp
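
A hedged usage sketch of the single-tree accountant, restating the docstring example with illustrative numbers (a tree over 2**10 leaves, a sample participating at most 20 times with at least 25 leaves between participations):

from tensorflow_privacy.privacy.analysis import rdp_accountant, tree_aggregation_accountant

orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = tree_aggregation_accountant.compute_rdp_single_tree(
    noise_multiplier=1.13, total_steps=2**10, max_participation=20,
    min_separation=25, orders=orders)
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)[0]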

tensorflow_privacy/privacy/analysis/tree_aggregation_accountant_test.py (new file)

@@ -0,0 +1,185 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rdp_accountant.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import rdp_accountant
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant
class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04))
def test_compute_eps_tree(self, noise_multiplier, eps):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
# This test is based on the StackOverflow setting in "Practical and
# Private (Deep) Learning without Sampling or Shuffling". The calculated
# epsilon may tighten as the accounting methods in this package improve.
steps_list, target_delta = 1600, 1e-6
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
new_eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
self.assertLess(new_eps, eps)
@parameterized.named_parameters(
('restart4', [400] * 4),
('restart2', [800] * 2),
('adaptive', [10, 400, 400, 400, 390]),
)
def test_compose_tree_rdp(self, steps_list):
noise_multiplier, orders = 0.1, 1
rdp_list = [
tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps, orders) for steps in steps_list
]
rdp_composed = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
self.assertAllClose(rdp_composed, sum(rdp_list), rtol=1e-12)
@parameterized.named_parameters(
('restart4', [400] * 4),
('restart2', [800] * 2),
('adaptive', [10, 400, 400, 400, 390]),
)
def test_compute_eps_tree_decreasing(self, steps_list):
# Test that privacy epsilon decreases as the noise multiplier increases,
# keeping other parameters the same.
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
target_delta = 1e-6
prev_eps = tree_aggregation_accountant.compute_rdp_tree_restart(
0, steps_list, orders)
for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
self.assertLess(eps, prev_eps)
@parameterized.named_parameters(
('negative_noise', -1, 3, 1),
('empty_steps', 1, [], 1),
('negative_steps', 1, -3, 1),
)
def test_compute_rdp_tree_restart_raise(self, noise_multiplier, steps_list,
orders):
with self.assertRaisesRegex(ValueError, 'must be'):
tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
@parameterized.named_parameters(
('t100n0.1', 100, 0.1),
('t1000n0.01', 1000, 0.01),
)
def test_no_tree_no_sampling(self, total_steps, noise_multiplier):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
tree_rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, [1] * total_steps, orders)
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
@parameterized.named_parameters(
('negative_noise', -1, 3, 1, 1),
('negative_steps', 0.1, -3, 1, 1),
('negative_part', 0.1, 3, -1, 1),
('negative_sep', 0.1, 3, 1, -1),
)
def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
max_participation, min_separation):
orders = 1
with self.assertRaisesRegex(ValueError, 'must be'):
tree_aggregation_accountant.compute_rdp_single_tree(
noise_multiplier, total_steps, max_participation, min_separation,
orders)
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step(self, steps):
max_participation, min_separation = steps, 0
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
steps_bin = bin(steps)[2:]
depth = [
len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
]
expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_participation, min_separation, steps))
@parameterized.named_parameters(
('11', 11),
('19', 19),
('200', 200),
)
def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
steps, min_separation = 8, 0
assert max_part > steps
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
expected = 120
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_part, min_separation, steps))
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
max_participation, min_separation = 2, 0
# If a sample will appear twice, the worst case is to put the two nodes at
# consecutive nodes of the deepest subtree.
steps_bin = bin(steps)[2:]
depth = len(steps_bin) - 1
expected = 2 + 4 * depth
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_participation, min_separation, steps))
@parameterized.named_parameters(
('test1', 1, 7, 8, 4),
('test2', 3, 3, 9, 11),
('test3', 3, 2, 7, 9),
# This is an example showing the worst-case sensitivity is larger than the
# greedy estimate in "Practical and Private (Deep) Learning without Sampling
# or Shuffling" https://arxiv.org/abs/2103.00039.
('test4', 8, 2, 24, 88),
)
def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
min_separation, steps, expected):
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_participation, min_separation, steps))
if __name__ == '__main__':
tf.test.main()