RDP accounting for tree aggregation without restart. This implements the dynamic programming algorithm detailed in the updated version of "Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.

PiperOrigin-RevId: 414583453
parent 49db04e356
commit 245fd069ca
2 changed files with 276 additions and 5 deletions
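Editor's note (not part of the commit): a minimal usage sketch of the new
`compute_rdp_single_tree` API added below. The import path follows the
library's usual layout, and the noise multiplier, step count, and
participation bounds are hypothetical values chosen for illustration.

    from tensorflow_privacy.privacy.analysis import rdp_accountant

    # Hypothetical DP-FTRL run: one tree over 64 steps, where each sample
    # participates at most 4 times, with at least 15 other steps between two
    # participations of the same sample.
    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
    rdp = rdp_accountant.compute_rdp_single_tree(
        noise_multiplier=1.0,
        total_steps=64,
        max_participation=4,
        min_separation=15,
        orders=orders)
    eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)
    print(f'epsilon = {eps:.3f} at delta = 1e-6')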
@@ -40,6 +40,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
 import math
 import sys
 from typing import Collection, Union
@@ -398,6 +399,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
   return rdp * steps


+# TODO(b/193679963): move accounting for tree aggregation to a separate module
 def _compute_rdp_tree_restart(sigma, steps_list, alpha):
   """Computes RDP of the Tree Aggregation Protocol at order alpha."""
   if np.isinf(alpha):
@@ -407,7 +409,8 @@ def _compute_rdp_tree_restart(sigma, steps_list, alpha):
       for steps in steps_list
       if steps > 0
   ]
-  return alpha * sum(tree_depths) / (2 * sigma**2)
+  return _compute_gaussian_rdp(
+      alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)


 def compute_rdp_tree_restart(
@@ -431,10 +434,8 @@ def compute_rdp_tree_restart(
   Returns:
     The RDPs at all orders. Can be `np.inf`.
   """
-  if noise_multiplier < 0:
-    raise ValueError(
-        f"Noise multiplier must be non-negative, got {noise_multiplier}")
-  elif noise_multiplier == 0:
+  _check_nonnegative(noise_multiplier, "noise_multiplier")
+  if noise_multiplier == 0:
     return np.inf

   if not steps_list:
@@ -460,6 +461,192 @@ def compute_rdp_tree_restart(
   return rdp


+def _check_nonnegative(value: Union[int, float], name: str):
+  if value < 0:
+    raise ValueError(f"Provided {name} must be non-negative, got {value}")
+
+
+def _check_possible_tree_participation(num_participation: int,
+                                       min_separation: int, start: int,
+                                       end: int, steps: int) -> bool:
+  """Check if participation is possible with `min_separation` in `steps`.
+
+  This function checks if it is possible for a sample to appear
+  `num_participation` times in `steps` steps, assuming there are at least
+  `min_separation` nodes between two appearances of the same sample in the
+  streaming data (leaf nodes in tree aggregation). The first appearance of
+  the sample is after `start` steps, and the sample won't appear in the `end`
+  steps after the given `steps`.
+
+  Args:
+    num_participation: The number of times a sample will appear.
+    min_separation: The minimum number of nodes between two appearances of a
+      sample. If a sample appears in consecutive steps x and y in a streaming
+      setting, then `min_separation=y-x-1`.
+    start: The first appearance of the sample is after `start` steps.
+    end: The sample won't appear in the `end` steps after the given `steps`.
+    steps: Total number of steps (leaf nodes in tree aggregation).
+
+  Returns:
+    True if a sample can appear `num_participation` times under the given
+    conditions.
+  """
+  return start + (min_separation + 1) * num_participation <= steps + end
+
+
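Editor's note (not part of the commit): a quick numeric check of the
feasibility inequality above. With `min_separation=1`, each appearance
consumes its own leaf plus one separating leaf, so 3 appearances need
3 * (1 + 1) = 6 slots out of the `steps + end` slots available after `start`:

    # 3 appearances, min_separation=1, start=0, end=0: needs 6 <= 8 slots.
    assert rdp_accountant._check_possible_tree_participation(
        3, 1, start=0, end=0, steps=8)
    # The same pattern cannot fit into 5 steps: 6 <= 5 is False.
    assert not rdp_accountant._check_possible_tree_participation(
        3, 1, start=0, end=0, steps=5)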
+@functools.lru_cache(maxsize=None)
+def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
+                                 start: int, end: int, size: int) -> float:
+  """Computes worst-case sum of sensitivity squares for `num_participation`.
+
+  This is the key algorithm for DP accounting for DP-FTRL tree aggregation
+  without restart, which recursively counts the worst-case occurrence of a
+  sample in all the nodes in a tree. This implements a dynamic programming
+  algorithm that exhausts the possible `num_participation` appearances of a
+  sample in `size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
+  "Practical and Private (Deep) Learning without Sampling or Shuffling"
+  https://arxiv.org/abs/2103.00039.
+
+  Args:
+    num_participation: The number of times a sample will appear.
+    min_separation: The minimum number of nodes between two appearances of a
+      sample. If a sample appears in consecutive steps x and y in a streaming
+      setting, then `min_separation=y-x-1`.
+    start: The first appearance of the sample is after `start` steps.
+    end: The sample won't appear in the `end` steps after the given `size`
+      steps.
+    size: Total number of steps (leaf nodes in tree aggregation).
+
+  Returns:
+    The worst-case sum of sensitivity squares for the given input.
+  """
+  if not _check_possible_tree_participation(num_participation, min_separation,
+                                            start, end, size):
+    sum_value = -np.inf
+  elif num_participation == 0:
+    sum_value = 0.
+  elif num_participation == 1 and size == 1:
+    sum_value = 1.
+  else:
+    size_log2 = math.log2(size)
+    max_2power = math.floor(size_log2)
+    if max_2power == size_log2:
+      sum_value = num_participation**2
+      max_2power -= 1
+    else:
+      sum_value = 0.
+    candidate_sum = []
+    # i is the `num_participation` in the right subtree.
+    for i in range(num_participation + 1):
+      # j is the `start` in the right subtree.
+      for j in range(min_separation + 1):
+        left_sum = _tree_sensitivity_square_sum(
+            num_participation=num_participation - i,
+            min_separation=min_separation,
+            start=start,
+            end=j,
+            size=2**max_2power)
+        if np.isinf(left_sum):
+          candidate_sum.append(-np.inf)
+          continue  # Early pruning for dynamic programming.
+        right_sum = _tree_sensitivity_square_sum(
+            num_participation=i,
+            min_separation=min_separation,
+            start=j,
+            end=end,
+            size=size - 2**max_2power)
+        candidate_sum.append(left_sum + right_sum)
+    sum_value += max(candidate_sum)
+  return sum_value
+
+
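Editor's note (not part of the commit): a small worked example of the
recursion above. For a tree over two leaves with one sample in both leaves
(`num_participation=2`, `min_separation=0`), the sample appears once in each
leaf (1^2 + 1^2) and twice in the root (2^2), so the worst-case sum of
sensitivity squares is 1 + 1 + 4 = 6:

    assert rdp_accountant._tree_sensitivity_square_sum(
        num_participation=2, min_separation=0, start=0, end=0, size=2) == 6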
+def _max_tree_sensitivity_square_sum(max_participation: int,
+                                     min_separation: int, steps: int) -> float:
+  """Computes the worst-case sum of sensitivity squares in tree aggregation.
+
+  See Appendix D.2 of
+  "Practical and Private (Deep) Learning without Sampling or Shuffling"
+  https://arxiv.org/abs/2103.00039.
+
+  Args:
+    max_participation: The maximum number of times a sample will appear.
+    min_separation: The minimum number of nodes between two appearances of a
+      sample. If a sample appears in consecutive steps x and y in a streaming
+      setting, then `min_separation=y-x-1`.
+    steps: Total number of steps (leaf nodes in tree aggregation).
+
+  Returns:
+    The worst-case sum of sensitivity squares for the given input.
+  """
+  num_participation = max_participation
+  while not _check_possible_tree_participation(
+      num_participation, min_separation, 0, min_separation, steps):
+    num_participation -= 1
+  candidate_sum = []
+  for num_part in range(1, num_participation + 1):
+    candidate_sum.append(
+        _tree_sensitivity_square_sum(num_part, min_separation, 0,
+                                     min_separation, steps))
+  return max(candidate_sum)
+
+
+def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
+                          alpha: float) -> float:
+  """Computes RDP of Gaussian mechanism."""
+  if np.isinf(alpha):
+    return np.inf
+  return alpha * sum_sensitivity_square / (2 * sigma**2)
+
+
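Editor's note (not part of the commit): this helper is the standard Renyi DP
bound for the Gaussian mechanism. For noise standard deviation sigma and total
squared l2-sensitivity S, the RDP at order alpha is

    \epsilon(\alpha) = \frac{\alpha S}{2 \sigma^2}

which reduces to the familiar alpha / (2 sigma^2) for a single release with
unit sensitivity (S = 1).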
+def compute_rdp_single_tree(
+    noise_multiplier: float, total_steps: int, max_participation: int,
+    min_separation: int,
+    orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
+  """Computes RDP of the Tree Aggregation Protocol for a single tree.
+
+  The accounting assumes a single tree is constructed for `total_steps` leaf
+  nodes, where the same sample will appear at most `max_participation` times,
+  and there are at least `min_separation` nodes between two appearances. The
+  key idea is to (recursively) count the worst-case occurrence of a sample in
+  all the nodes in a tree, which implements a dynamic programming algorithm
+  that exhausts the possible `num_participation` appearances of a sample in
+  `steps` leaf nodes.
+
+  See Appendix D of
+  "Practical and Private (Deep) Learning without Sampling or Shuffling"
+  https://arxiv.org/abs/2103.00039.
+
+  Args:
+    noise_multiplier: A non-negative float representing the ratio of the
+      standard deviation of the Gaussian noise to the l2-sensitivity of the
+      function to which it is added.
+    total_steps: Total number of steps (leaf nodes in tree aggregation).
+    max_participation: The maximum number of times a sample can appear.
+    min_separation: The minimum number of nodes between two appearances of a
+      sample. If a sample appears in consecutive steps x and y in a streaming
+      setting, then `min_separation=y-x-1`.
+    orders: An array (or a scalar) of RDP orders.
+
+  Returns:
+    The RDPs at all orders. Can be `np.inf`.
+  """
+  _check_nonnegative(noise_multiplier, "noise_multiplier")
+  if noise_multiplier == 0:
+    return np.inf
+  _check_nonnegative(total_steps, "total_steps")
+  _check_nonnegative(max_participation, "max_participation")
+  _check_nonnegative(min_separation, "min_separation")
+  sum_sensitivity_square = _max_tree_sensitivity_square_sum(
+      max_participation, min_separation, total_steps)
+  if np.isscalar(orders):
+    rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
+                                orders)
+  else:
+    rdp = np.array([
+        _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
+        for alpha in orders
+    ])
+  return rdp
+
+
 def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
   """Compute RDP of Gaussian Mechanism using sampling without replacement.

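Editor's note (not part of the commit): a sanity check relating the two
accountants, assuming `compute_rdp_tree_restart`'s signature as shown in this
file. With `max_participation=1` a sample contributes to exactly one
root-to-leaf path, so the squared sensitivity equals the tree depth and the
result should match the restart accountant on a single tree:

    orders = 2.0
    rdp_single = rdp_accountant.compute_rdp_single_tree(
        noise_multiplier=1.0, total_steps=8, max_participation=1,
        min_separation=0, orders=orders)
    rdp_restart = rdp_accountant.compute_rdp_tree_restart(
        noise_multiplier=1.0, steps_list=[8], orders=orders)
    # A tree over 8 leaves has depth 3, i.e. 4 nodes on a root-to-leaf path,
    # so both should equal 2.0 * 4 / (2 * 1.0**2) = 4.0.
    assert rdp_single == rdp_restart == 4.0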
@@ -335,6 +335,90 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
     rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
     self.assertAllClose(tree_rdp, rdp, rtol=1e-12)

+  @parameterized.named_parameters(
+      ('negative_noise', -1, 3, 1, 1),
+      ('negative_steps', 0.1, -3, 1, 1),
+      ('negative_part', 0.1, 3, -1, 1),
+      ('negative_sep', 0.1, 3, 1, -1),
+  )
+  def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
+                                         max_participation, min_separation):
+    orders = 1
+    with self.assertRaisesRegex(ValueError, 'must be'):
+      rdp_accountant.compute_rdp_single_tree(noise_multiplier, total_steps,
+                                             max_participation, min_separation,
+                                             orders)
+
+  @parameterized.named_parameters(
+      ('3', 3),
+      ('8', 8),
+      ('11', 11),
+      ('19', 19),
+  )
+  def test_max_tree_sensitivity_square_sum_every_step(self, steps):
+    max_participation, min_separation = steps, 0
+    # If a sample appears in every leaf node, we can infer the total
+    # sensitivity by adding up all the nodes.
+    steps_bin = bin(steps)[2:]
+    depth = [
+        len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
+    ]
+    expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
+    self.assertEqual(
+        expected,
+        rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
+                                                        min_separation, steps))
+
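Editor's note on `expected` (not part of the commit): the tree over `steps`
leaves decomposes into one perfect subtree of 2^d leaves for each set bit d in
the binary representation of `steps`. When the sample is in every leaf, a node
at height h of a perfect subtree of depth d covers 2^h leaves and contributes
(2^h)^2, and there are 2^(d-h) such nodes, so each subtree contributes

    \sum_{h=0}^{d} 2^{d-h} (2^h)^2 = 2^d \sum_{h=0}^{d} 2^h = 2^d (2^{d+1} - 1)

which is exactly the summand in `expected`.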
+  @parameterized.named_parameters(
+      ('11', 11),
+      ('19', 19),
+      ('200', 200),
+  )
+  def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
+    steps, min_separation = 8, 0
+    assert max_part > steps
+    # `max_part` is capped at `steps`, so the sample appears in every leaf
+    # node and the total sensitivity is the sum over all nodes: 120 for a
+    # perfect tree over 8 leaves.
+    expected = 120
+    self.assertEqual(
+        expected,
+        rdp_accountant._max_tree_sensitivity_square_sum(max_part,
+                                                        min_separation, steps))
+
+  @parameterized.named_parameters(
+      ('3', 3),
+      ('8', 8),
+      ('11', 11),
+      ('19', 19),
+  )
+  def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
+    max_participation, min_separation = 2, 0
+    # If a sample appears twice, the worst case is to put the two appearances
+    # at consecutive leaf nodes of the deepest subtree.
+    steps_bin = bin(steps)[2:]
+    depth = len(steps_bin) - 1
+    expected = 2 + 4 * depth
+    self.assertEqual(
+        expected,
+        rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
+                                                        min_separation, steps))
+
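Editor's note on `expected` (not part of the commit): the two appearances
contribute 1^2 + 1^2 = 2 at the leaves, and placing them at adjacent leaves of
the deepest subtree makes them share all `depth` ancestors above the leaf
level; each shared node covers both appearances and contributes 2^2 = 4,
giving 2 + 4 * depth (e.g. 14 for steps = 8, where depth = 3).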
+  @parameterized.named_parameters(
+      ('test1', 1, 7, 8, 4),
+      ('test2', 3, 3, 9, 11),
+      ('test3', 3, 2, 7, 9),
+      # This is an example showing that the worst-case sensitivity is larger
+      # than the greedy placement described in "Practical and Private (Deep)
+      # Learning without Sampling or Shuffling"
+      # https://arxiv.org/abs/2103.00039.
+      ('test4', 8, 2, 24, 88),
+  )
+  def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
+                                               min_separation, steps, expected):
+    self.assertEqual(
+        expected,
+        rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
+                                                        min_separation, steps))
+
 if __name__ == '__main__':
   tf.test.main()