RDP accounting for tree aggregation without restart. This implements the dynamic programming algorithm detailed in the updated version of "Practical and Private (Deep) Learning without Sampling or Shuffling"

https://arxiv.org/abs/2103.00039.

PiperOrigin-RevId: 414583453
This commit is contained in:
Zheng Xu 2021-12-06 17:38:14 -08:00 committed by A. Unique TensorFlower
parent 49db04e356
commit 245fd069ca
2 changed files with 276 additions and 5 deletions

View file

@ -40,6 +40,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import math
import sys
from typing import Collection, Union
@ -398,6 +399,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
return rdp * steps
# TODO(b/193679963): move accounting for tree aggregation to a separate module
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
if np.isinf(alpha):
@ -407,7 +409,8 @@ def _compute_rdp_tree_restart(sigma, steps_list, alpha):
for steps in steps_list
if steps > 0
]
return alpha * sum(tree_depths) / (2 * sigma**2)
return _compute_gaussian_rdp(
alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)
def compute_rdp_tree_restart(
@ -431,10 +434,8 @@ def compute_rdp_tree_restart(
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
if noise_multiplier < 0:
raise ValueError(
f"Noise multiplier must be non-negative, got {noise_multiplier}")
elif noise_multiplier == 0:
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
if not steps_list:
@ -460,6 +461,192 @@ def compute_rdp_tree_restart(
return rdp
def _check_nonnegative(value: Union[int, float], name: str):
if value < 0:
raise ValueError(f"Provided {name} must be non-negative, got {value}")
def _check_possible_tree_participation(num_participation: int,
min_separation: int, start: int,
end: int, steps: int) -> bool:
"""Check if participation is possible with `min_separation` in `steps`.
This function checks if it is possible for a sample to appear
`num_participation` in `steps`, assuming there are at least `min_separation`
nodes between the appearance of the same sample in the streaming data (leaf
nodes in tree aggregation). The first appearance of the sample is after
`start` steps, and the sample won't appear in the `end` steps after the given
`steps`.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearance of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
end: The sample won't appear in the `end` steps after the given `steps`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
True if a sample can appear `num_participation` with given conditions.
"""
return start + (min_separation + 1) * num_participation <= steps + end
@functools.lru_cache(maxsize=None)
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
start: int, end: int, size: int) -> float:
"""Compute the worst-case sum of sensitivtiy square for `num_participation`.
This is the key algorithm for DP accounting for DP-FTRL tree aggregation
without restart, which recurrently counts the worst-case occurence of a sample
in all the nodes in a tree. This implements a dynamic programming algorithm
that exhausts the possible `num_participation` appearance of a sample in
`size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearance of a
sample. If a sample appears in consecutive x, y size in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
end: The sample won't appear in the `end` steps after given `size` steps.
size: Total number of steps (leaf nodes in tree aggregation).
Returns:
The worst-case sum of sensitivity square for the given input.
"""
if not _check_possible_tree_participation(num_participation, min_separation,
start, end, size):
sum_value = -np.inf
elif num_participation == 0:
sum_value = 0.
elif num_participation == 1 and size == 1:
sum_value = 1.
else:
size_log2 = math.log2(size)
max_2power = math.floor(size_log2)
if max_2power == size_log2:
sum_value = num_participation**2
max_2power -= 1
else:
sum_value = 0.
candidate_sum = []
# i is the `num_participation` in the right subtree
for i in range(num_participation + 1):
# j is the `start` in the right subtree
for j in range(min_separation + 1):
left_sum = _tree_sensitivity_square_sum(
num_participation=num_participation - i,
min_separation=min_separation,
start=start,
end=j,
size=2**max_2power)
if np.isinf(left_sum):
candidate_sum.append(-np.inf)
continue # Early pruning for dynamic programming
right_sum = _tree_sensitivity_square_sum(
num_participation=i,
min_separation=min_separation,
start=j,
end=end,
size=size - 2**max_2power)
candidate_sum.append(left_sum + right_sum)
sum_value += max(candidate_sum)
return sum_value
def _max_tree_sensitivity_square_sum(max_participation: int,
min_separation: int, steps: int) -> float:
"""Compute the worst-case sum of sensitivity square in tree aggregation.
See Appendix D.2 of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
max_participation: The maximum number of times a sample will appear.
min_separation: The minimum number of nodes between two appearance of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
The worst-case sum of sensitivity square for the given input.
"""
num_participation = max_participation
while not _check_possible_tree_participation(
num_participation, min_separation, 0, min_separation, steps):
num_participation -= 1
candidate_sum = []
for num_part in range(1, num_participation + 1):
candidate_sum.append(
_tree_sensitivity_square_sum(num_part, min_separation, 0,
min_separation, steps))
return max(candidate_sum)
def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
alpha: float) -> float:
"""Computes RDP of Gaussian mechanism."""
if np.isinf(alpha):
return np.inf
return alpha * sum_sensitivity_square / (2 * sigma**2)
def compute_rdp_single_tree(
noise_multiplier: float, total_steps: int, max_participation: int,
min_separation: int,
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
"""Computes RDP of the Tree Aggregation Protocol for a single tree.
The accounting assume a single tree is constructed for `total_steps` leaf
nodes, where the same sample will appear at most `max_participation` times,
and there are at least `min_separation` nodes between two appearance. The key
idea is to (recurrently) count the worst-case occurence of a sample
in all the nodes in a tree, which implements a dynamic programming algorithm
that exhausts the possible `num_participation` appearance of a sample in
`steps` leaf nodes.
See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
total_steps: Total number of steps (leaf nodes in tree aggregation).
max_participation: The maximum number of times a sample can appear.
min_separation: The minimum number of nodes between two appearance of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
_check_nonnegative(total_steps, "total_steps")
_check_nonnegative(max_participation, "max_participation")
_check_nonnegative(min_separation, "min_separation")
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
max_participation, min_separation, total_steps)
if np.isscalar(orders):
rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
orders)
else:
rdp = np.array([
_compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
for alpha in orders
])
return rdp
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
"""Compute RDP of Gaussian Mechanism using sampling without replacement.

View file

@ -335,6 +335,90 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
@parameterized.named_parameters(
('negative_noise', -1, 3, 1, 1),
('negative_steps', 0.1, -3, 1, 1),
('negative_part', 0.1, 3, -1, 1),
('negative_sep', 0.1, 3, 1, -1),
)
def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
max_participation, min_separation):
orders = 1
with self.assertRaisesRegex(ValueError, 'must be'):
rdp_accountant.compute_rdp_single_tree(noise_multiplier, total_steps,
max_participation, min_separation,
orders)
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step(self, steps):
max_participation, min_separation = steps, 0
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
steps_bin = bin(steps)[2:]
depth = [
len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
]
expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
min_separation, steps))
@parameterized.named_parameters(
('11', 11),
('19', 19),
('200', 200),
)
def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
steps, min_separation = 8, 0
assert max_part > steps
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
expected = 120
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_part,
min_separation, steps))
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
max_participation, min_separation = 2, 0
# If a sample will appear twice, the worst case is to put the two nodes at
# consecutive nodes of the deepest subtree.
steps_bin = bin(steps)[2:]
depth = len(steps_bin) - 1
expected = 2 + 4 * depth
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
min_separation, steps))
@parameterized.named_parameters(
('test1', 1, 7, 8, 4),
('test2', 3, 3, 9, 11),
('test3', 3, 2, 7, 9),
# This is an example showing worst-case sensitivity is larger than greedy
# in "Practical and Private (Deep) Learning without Sampling or Shuffling"
# https://arxiv.org/abs/2103.00039.
('test4', 8, 2, 24, 88),
)
def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
min_separation, steps, expected):
self.assertEqual(
expected,
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
min_separation, steps))
if __name__ == '__main__':
tf.test.main()