RDP accounting for tree aggregation without restart. This implements the dynamic programming algorithm detailed in the updated version of "Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039. PiperOrigin-RevId: 414583453
This commit is contained in:
parent
49db04e356
commit
245fd069ca
2 changed files with 276 additions and 5 deletions
|
@ -40,6 +40,7 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
from typing import Collection, Union
|
from typing import Collection, Union
|
||||||
|
@ -398,6 +399,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
|
||||||
return rdp * steps
|
return rdp * steps
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/193679963): move accounting for tree aggregation to a separate module
|
||||||
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
|
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
|
||||||
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
|
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
|
||||||
if np.isinf(alpha):
|
if np.isinf(alpha):
|
||||||
|
@ -407,7 +409,8 @@ def _compute_rdp_tree_restart(sigma, steps_list, alpha):
|
||||||
for steps in steps_list
|
for steps in steps_list
|
||||||
if steps > 0
|
if steps > 0
|
||||||
]
|
]
|
||||||
return alpha * sum(tree_depths) / (2 * sigma**2)
|
return _compute_gaussian_rdp(
|
||||||
|
alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)
|
||||||
|
|
||||||
|
|
||||||
def compute_rdp_tree_restart(
|
def compute_rdp_tree_restart(
|
||||||
|
@ -431,10 +434,8 @@ def compute_rdp_tree_restart(
|
||||||
Returns:
|
Returns:
|
||||||
The RDPs at all orders. Can be `np.inf`.
|
The RDPs at all orders. Can be `np.inf`.
|
||||||
"""
|
"""
|
||||||
if noise_multiplier < 0:
|
_check_nonnegative(noise_multiplier, "noise_multiplier")
|
||||||
raise ValueError(
|
if noise_multiplier == 0:
|
||||||
f"Noise multiplier must be non-negative, got {noise_multiplier}")
|
|
||||||
elif noise_multiplier == 0:
|
|
||||||
return np.inf
|
return np.inf
|
||||||
|
|
||||||
if not steps_list:
|
if not steps_list:
|
||||||
|
@ -460,6 +461,192 @@ def compute_rdp_tree_restart(
|
||||||
return rdp
|
return rdp
|
||||||
|
|
||||||
|
|
||||||
|
def _check_nonnegative(value: Union[int, float], name: str):
|
||||||
|
if value < 0:
|
||||||
|
raise ValueError(f"Provided {name} must be non-negative, got {value}")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_possible_tree_participation(num_participation: int,
|
||||||
|
min_separation: int, start: int,
|
||||||
|
end: int, steps: int) -> bool:
|
||||||
|
"""Check if participation is possible with `min_separation` in `steps`.
|
||||||
|
|
||||||
|
This function checks if it is possible for a sample to appear
|
||||||
|
`num_participation` in `steps`, assuming there are at least `min_separation`
|
||||||
|
nodes between the appearance of the same sample in the streaming data (leaf
|
||||||
|
nodes in tree aggregation). The first appearance of the sample is after
|
||||||
|
`start` steps, and the sample won't appear in the `end` steps after the given
|
||||||
|
`steps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_participation: The number of times a sample will appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y steps in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
start: The first appearance of the sample is after `start` steps.
|
||||||
|
end: The sample won't appear in the `end` steps after the given `steps`.
|
||||||
|
steps: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a sample can appear `num_participation` with given conditions.
|
||||||
|
"""
|
||||||
|
return start + (min_separation + 1) * num_participation <= steps + end
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
|
||||||
|
start: int, end: int, size: int) -> float:
|
||||||
|
"""Compute the worst-case sum of sensitivtiy square for `num_participation`.
|
||||||
|
|
||||||
|
This is the key algorithm for DP accounting for DP-FTRL tree aggregation
|
||||||
|
without restart, which recurrently counts the worst-case occurence of a sample
|
||||||
|
in all the nodes in a tree. This implements a dynamic programming algorithm
|
||||||
|
that exhausts the possible `num_participation` appearance of a sample in
|
||||||
|
`size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_participation: The number of times a sample will appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y size in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
start: The first appearance of the sample is after `start` steps.
|
||||||
|
end: The sample won't appear in the `end` steps after given `size` steps.
|
||||||
|
size: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The worst-case sum of sensitivity square for the given input.
|
||||||
|
"""
|
||||||
|
if not _check_possible_tree_participation(num_participation, min_separation,
|
||||||
|
start, end, size):
|
||||||
|
sum_value = -np.inf
|
||||||
|
elif num_participation == 0:
|
||||||
|
sum_value = 0.
|
||||||
|
elif num_participation == 1 and size == 1:
|
||||||
|
sum_value = 1.
|
||||||
|
else:
|
||||||
|
size_log2 = math.log2(size)
|
||||||
|
max_2power = math.floor(size_log2)
|
||||||
|
if max_2power == size_log2:
|
||||||
|
sum_value = num_participation**2
|
||||||
|
max_2power -= 1
|
||||||
|
else:
|
||||||
|
sum_value = 0.
|
||||||
|
candidate_sum = []
|
||||||
|
# i is the `num_participation` in the right subtree
|
||||||
|
for i in range(num_participation + 1):
|
||||||
|
# j is the `start` in the right subtree
|
||||||
|
for j in range(min_separation + 1):
|
||||||
|
left_sum = _tree_sensitivity_square_sum(
|
||||||
|
num_participation=num_participation - i,
|
||||||
|
min_separation=min_separation,
|
||||||
|
start=start,
|
||||||
|
end=j,
|
||||||
|
size=2**max_2power)
|
||||||
|
if np.isinf(left_sum):
|
||||||
|
candidate_sum.append(-np.inf)
|
||||||
|
continue # Early pruning for dynamic programming
|
||||||
|
right_sum = _tree_sensitivity_square_sum(
|
||||||
|
num_participation=i,
|
||||||
|
min_separation=min_separation,
|
||||||
|
start=j,
|
||||||
|
end=end,
|
||||||
|
size=size - 2**max_2power)
|
||||||
|
candidate_sum.append(left_sum + right_sum)
|
||||||
|
sum_value += max(candidate_sum)
|
||||||
|
return sum_value
|
||||||
|
|
||||||
|
|
||||||
|
def _max_tree_sensitivity_square_sum(max_participation: int,
|
||||||
|
min_separation: int, steps: int) -> float:
|
||||||
|
"""Compute the worst-case sum of sensitivity square in tree aggregation.
|
||||||
|
|
||||||
|
See Appendix D.2 of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_participation: The maximum number of times a sample will appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y steps in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
steps: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The worst-case sum of sensitivity square for the given input.
|
||||||
|
"""
|
||||||
|
num_participation = max_participation
|
||||||
|
while not _check_possible_tree_participation(
|
||||||
|
num_participation, min_separation, 0, min_separation, steps):
|
||||||
|
num_participation -= 1
|
||||||
|
candidate_sum = []
|
||||||
|
for num_part in range(1, num_participation + 1):
|
||||||
|
candidate_sum.append(
|
||||||
|
_tree_sensitivity_square_sum(num_part, min_separation, 0,
|
||||||
|
min_separation, steps))
|
||||||
|
return max(candidate_sum)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
|
||||||
|
alpha: float) -> float:
|
||||||
|
"""Computes RDP of Gaussian mechanism."""
|
||||||
|
if np.isinf(alpha):
|
||||||
|
return np.inf
|
||||||
|
return alpha * sum_sensitivity_square / (2 * sigma**2)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_rdp_single_tree(
|
||||||
|
noise_multiplier: float, total_steps: int, max_participation: int,
|
||||||
|
min_separation: int,
|
||||||
|
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
|
||||||
|
"""Computes RDP of the Tree Aggregation Protocol for a single tree.
|
||||||
|
|
||||||
|
The accounting assume a single tree is constructed for `total_steps` leaf
|
||||||
|
nodes, where the same sample will appear at most `max_participation` times,
|
||||||
|
and there are at least `min_separation` nodes between two appearance. The key
|
||||||
|
idea is to (recurrently) count the worst-case occurence of a sample
|
||||||
|
in all the nodes in a tree, which implements a dynamic programming algorithm
|
||||||
|
that exhausts the possible `num_participation` appearance of a sample in
|
||||||
|
`steps` leaf nodes.
|
||||||
|
|
||||||
|
See Appendix D of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noise_multiplier: A non-negative float representing the ratio of the
|
||||||
|
standard deviation of the Gaussian noise to the l2-sensitivity of the
|
||||||
|
function to which it is added.
|
||||||
|
total_steps: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
max_participation: The maximum number of times a sample can appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y steps in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
orders: An array (or a scalar) of RDP orders.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The RDPs at all orders. Can be `np.inf`.
|
||||||
|
"""
|
||||||
|
_check_nonnegative(noise_multiplier, "noise_multiplier")
|
||||||
|
if noise_multiplier == 0:
|
||||||
|
return np.inf
|
||||||
|
_check_nonnegative(total_steps, "total_steps")
|
||||||
|
_check_nonnegative(max_participation, "max_participation")
|
||||||
|
_check_nonnegative(min_separation, "min_separation")
|
||||||
|
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
|
||||||
|
max_participation, min_separation, total_steps)
|
||||||
|
if np.isscalar(orders):
|
||||||
|
rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
|
||||||
|
orders)
|
||||||
|
else:
|
||||||
|
rdp = np.array([
|
||||||
|
_compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
|
||||||
|
for alpha in orders
|
||||||
|
])
|
||||||
|
return rdp
|
||||||
|
|
||||||
|
|
||||||
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
||||||
"""Compute RDP of Gaussian Mechanism using sampling without replacement.
|
"""Compute RDP of Gaussian Mechanism using sampling without replacement.
|
||||||
|
|
||||||
|
|
|
@ -335,6 +335,90 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
|
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
|
||||||
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
|
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('negative_noise', -1, 3, 1, 1),
|
||||||
|
('negative_steps', 0.1, -3, 1, 1),
|
||||||
|
('negative_part', 0.1, 3, -1, 1),
|
||||||
|
('negative_sep', 0.1, 3, 1, -1),
|
||||||
|
)
|
||||||
|
def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
|
||||||
|
max_participation, min_separation):
|
||||||
|
orders = 1
|
||||||
|
with self.assertRaisesRegex(ValueError, 'must be'):
|
||||||
|
rdp_accountant.compute_rdp_single_tree(noise_multiplier, total_steps,
|
||||||
|
max_participation, min_separation,
|
||||||
|
orders)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('3', 3),
|
||||||
|
('8', 8),
|
||||||
|
('11', 11),
|
||||||
|
('19', 19),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_every_step(self, steps):
|
||||||
|
max_participation, min_separation = steps, 0
|
||||||
|
# If a sample will appear in every leaf node, we can infer the total
|
||||||
|
# sensitivity by adding all the nodes.
|
||||||
|
steps_bin = bin(steps)[2:]
|
||||||
|
depth = [
|
||||||
|
len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
|
||||||
|
]
|
||||||
|
expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
|
||||||
|
min_separation, steps))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('11', 11),
|
||||||
|
('19', 19),
|
||||||
|
('200', 200),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
|
||||||
|
steps, min_separation = 8, 0
|
||||||
|
assert max_part > steps
|
||||||
|
# If a sample will appear in every leaf node, we can infer the total
|
||||||
|
# sensitivity by adding all the nodes.
|
||||||
|
expected = 120
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
rdp_accountant._max_tree_sensitivity_square_sum(max_part,
|
||||||
|
min_separation, steps))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('3', 3),
|
||||||
|
('8', 8),
|
||||||
|
('11', 11),
|
||||||
|
('19', 19),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
|
||||||
|
max_participation, min_separation = 2, 0
|
||||||
|
# If a sample will appear twice, the worst case is to put the two nodes at
|
||||||
|
# consecutive nodes of the deepest subtree.
|
||||||
|
steps_bin = bin(steps)[2:]
|
||||||
|
depth = len(steps_bin) - 1
|
||||||
|
expected = 2 + 4 * depth
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
|
||||||
|
min_separation, steps))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('test1', 1, 7, 8, 4),
|
||||||
|
('test2', 3, 3, 9, 11),
|
||||||
|
('test3', 3, 2, 7, 9),
|
||||||
|
# This is an example showing worst-case sensitivity is larger than greedy
|
||||||
|
# in "Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
# https://arxiv.org/abs/2103.00039.
|
||||||
|
('test4', 8, 2, 24, 88),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
|
||||||
|
min_separation, steps, expected):
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
|
||||||
|
min_separation, steps))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue