Move tree_aggregation accountant to their own module.
PiperOrigin-RevId: 414770173
This commit is contained in:
parent
245fd069ca
commit
8850c23f67
5 changed files with 502 additions and 408 deletions
|
@ -45,8 +45,9 @@ else:
|
||||||
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
|
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_heterogeneous_rdp
|
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_heterogeneous_rdp
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_tree_restart
|
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||||
|
from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_tree_restart
|
||||||
|
from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_single_tree
|
||||||
|
|
||||||
# DPQuery classes
|
# DPQuery classes
|
||||||
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
|
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
|
||||||
|
|
|
@ -40,11 +40,8 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import functools
|
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
from typing import Collection, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import special
|
from scipy import special
|
||||||
import six
|
import six
|
||||||
|
@ -399,254 +396,6 @@ def compute_rdp(q, noise_multiplier, steps, orders):
|
||||||
return rdp * steps
|
return rdp * steps
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/193679963): move accounting for tree aggregation to a separate module
|
|
||||||
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
|
|
||||||
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
|
|
||||||
if np.isinf(alpha):
|
|
||||||
return np.inf
|
|
||||||
tree_depths = [
|
|
||||||
math.floor(math.log2(float(steps))) + 1
|
|
||||||
for steps in steps_list
|
|
||||||
if steps > 0
|
|
||||||
]
|
|
||||||
return _compute_gaussian_rdp(
|
|
||||||
alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_rdp_tree_restart(
|
|
||||||
noise_multiplier: float, steps_list: Union[int, Collection[int]],
|
|
||||||
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
|
|
||||||
"""Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism.
|
|
||||||
|
|
||||||
This function implements the accounting when the tree is restarted at every
|
|
||||||
epoch. See appendix of
|
|
||||||
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
|
||||||
https://arxiv.org/abs/2103.00039.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
noise_multiplier: A non-negative float representing the ratio of the
|
|
||||||
standard deviation of the Gaussian noise to the l2-sensitivity of the
|
|
||||||
function to which it is added.
|
|
||||||
steps_list: A scalar or a list of non-negative intergers representing the
|
|
||||||
number of steps per epoch (between two restarts).
|
|
||||||
orders: An array (or a scalar) of RDP orders.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The RDPs at all orders. Can be `np.inf`.
|
|
||||||
"""
|
|
||||||
_check_nonnegative(noise_multiplier, "noise_multiplier")
|
|
||||||
if noise_multiplier == 0:
|
|
||||||
return np.inf
|
|
||||||
|
|
||||||
if not steps_list:
|
|
||||||
raise ValueError(
|
|
||||||
"steps_list must be a non-empty list, or a non-zero scalar, got "
|
|
||||||
f"{steps_list}.")
|
|
||||||
|
|
||||||
if np.isscalar(steps_list):
|
|
||||||
steps_list = [steps_list]
|
|
||||||
|
|
||||||
for steps in steps_list:
|
|
||||||
if steps < 0:
|
|
||||||
raise ValueError(f"Steps must be non-negative, got {steps_list}")
|
|
||||||
|
|
||||||
if np.isscalar(orders):
|
|
||||||
rdp = _compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
|
|
||||||
else:
|
|
||||||
rdp = np.array([
|
|
||||||
_compute_rdp_tree_restart(noise_multiplier, steps_list, alpha)
|
|
||||||
for alpha in orders
|
|
||||||
])
|
|
||||||
|
|
||||||
return rdp
|
|
||||||
|
|
||||||
|
|
||||||
def _check_nonnegative(value: Union[int, float], name: str):
|
|
||||||
if value < 0:
|
|
||||||
raise ValueError(f"Provided {name} must be non-negative, got {value}")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_possible_tree_participation(num_participation: int,
|
|
||||||
min_separation: int, start: int,
|
|
||||||
end: int, steps: int) -> bool:
|
|
||||||
"""Check if participation is possible with `min_separation` in `steps`.
|
|
||||||
|
|
||||||
This function checks if it is possible for a sample to appear
|
|
||||||
`num_participation` in `steps`, assuming there are at least `min_separation`
|
|
||||||
nodes between the appearance of the same sample in the streaming data (leaf
|
|
||||||
nodes in tree aggregation). The first appearance of the sample is after
|
|
||||||
`start` steps, and the sample won't appear in the `end` steps after the given
|
|
||||||
`steps`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_participation: The number of times a sample will appear.
|
|
||||||
min_separation: The minimum number of nodes between two appearance of a
|
|
||||||
sample. If a sample appears in consecutive x, y steps in a streaming
|
|
||||||
setting, then `min_separation=y-x-1`.
|
|
||||||
start: The first appearance of the sample is after `start` steps.
|
|
||||||
end: The sample won't appear in the `end` steps after the given `steps`.
|
|
||||||
steps: Total number of steps (leaf nodes in tree aggregation).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if a sample can appear `num_participation` with given conditions.
|
|
||||||
"""
|
|
||||||
return start + (min_separation + 1) * num_participation <= steps + end
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
|
||||||
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
|
|
||||||
start: int, end: int, size: int) -> float:
|
|
||||||
"""Compute the worst-case sum of sensitivtiy square for `num_participation`.
|
|
||||||
|
|
||||||
This is the key algorithm for DP accounting for DP-FTRL tree aggregation
|
|
||||||
without restart, which recurrently counts the worst-case occurence of a sample
|
|
||||||
in all the nodes in a tree. This implements a dynamic programming algorithm
|
|
||||||
that exhausts the possible `num_participation` appearance of a sample in
|
|
||||||
`size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
|
|
||||||
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
|
||||||
https://arxiv.org/abs/2103.00039.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_participation: The number of times a sample will appear.
|
|
||||||
min_separation: The minimum number of nodes between two appearance of a
|
|
||||||
sample. If a sample appears in consecutive x, y size in a streaming
|
|
||||||
setting, then `min_separation=y-x-1`.
|
|
||||||
start: The first appearance of the sample is after `start` steps.
|
|
||||||
end: The sample won't appear in the `end` steps after given `size` steps.
|
|
||||||
size: Total number of steps (leaf nodes in tree aggregation).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The worst-case sum of sensitivity square for the given input.
|
|
||||||
"""
|
|
||||||
if not _check_possible_tree_participation(num_participation, min_separation,
|
|
||||||
start, end, size):
|
|
||||||
sum_value = -np.inf
|
|
||||||
elif num_participation == 0:
|
|
||||||
sum_value = 0.
|
|
||||||
elif num_participation == 1 and size == 1:
|
|
||||||
sum_value = 1.
|
|
||||||
else:
|
|
||||||
size_log2 = math.log2(size)
|
|
||||||
max_2power = math.floor(size_log2)
|
|
||||||
if max_2power == size_log2:
|
|
||||||
sum_value = num_participation**2
|
|
||||||
max_2power -= 1
|
|
||||||
else:
|
|
||||||
sum_value = 0.
|
|
||||||
candidate_sum = []
|
|
||||||
# i is the `num_participation` in the right subtree
|
|
||||||
for i in range(num_participation + 1):
|
|
||||||
# j is the `start` in the right subtree
|
|
||||||
for j in range(min_separation + 1):
|
|
||||||
left_sum = _tree_sensitivity_square_sum(
|
|
||||||
num_participation=num_participation - i,
|
|
||||||
min_separation=min_separation,
|
|
||||||
start=start,
|
|
||||||
end=j,
|
|
||||||
size=2**max_2power)
|
|
||||||
if np.isinf(left_sum):
|
|
||||||
candidate_sum.append(-np.inf)
|
|
||||||
continue # Early pruning for dynamic programming
|
|
||||||
right_sum = _tree_sensitivity_square_sum(
|
|
||||||
num_participation=i,
|
|
||||||
min_separation=min_separation,
|
|
||||||
start=j,
|
|
||||||
end=end,
|
|
||||||
size=size - 2**max_2power)
|
|
||||||
candidate_sum.append(left_sum + right_sum)
|
|
||||||
sum_value += max(candidate_sum)
|
|
||||||
return sum_value
|
|
||||||
|
|
||||||
|
|
||||||
def _max_tree_sensitivity_square_sum(max_participation: int,
|
|
||||||
min_separation: int, steps: int) -> float:
|
|
||||||
"""Compute the worst-case sum of sensitivity square in tree aggregation.
|
|
||||||
|
|
||||||
See Appendix D.2 of
|
|
||||||
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
|
||||||
https://arxiv.org/abs/2103.00039.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_participation: The maximum number of times a sample will appear.
|
|
||||||
min_separation: The minimum number of nodes between two appearance of a
|
|
||||||
sample. If a sample appears in consecutive x, y steps in a streaming
|
|
||||||
setting, then `min_separation=y-x-1`.
|
|
||||||
steps: Total number of steps (leaf nodes in tree aggregation).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The worst-case sum of sensitivity square for the given input.
|
|
||||||
"""
|
|
||||||
num_participation = max_participation
|
|
||||||
while not _check_possible_tree_participation(
|
|
||||||
num_participation, min_separation, 0, min_separation, steps):
|
|
||||||
num_participation -= 1
|
|
||||||
candidate_sum = []
|
|
||||||
for num_part in range(1, num_participation + 1):
|
|
||||||
candidate_sum.append(
|
|
||||||
_tree_sensitivity_square_sum(num_part, min_separation, 0,
|
|
||||||
min_separation, steps))
|
|
||||||
return max(candidate_sum)
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
|
|
||||||
alpha: float) -> float:
|
|
||||||
"""Computes RDP of Gaussian mechanism."""
|
|
||||||
if np.isinf(alpha):
|
|
||||||
return np.inf
|
|
||||||
return alpha * sum_sensitivity_square / (2 * sigma**2)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_rdp_single_tree(
|
|
||||||
noise_multiplier: float, total_steps: int, max_participation: int,
|
|
||||||
min_separation: int,
|
|
||||||
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
|
|
||||||
"""Computes RDP of the Tree Aggregation Protocol for a single tree.
|
|
||||||
|
|
||||||
The accounting assume a single tree is constructed for `total_steps` leaf
|
|
||||||
nodes, where the same sample will appear at most `max_participation` times,
|
|
||||||
and there are at least `min_separation` nodes between two appearance. The key
|
|
||||||
idea is to (recurrently) count the worst-case occurence of a sample
|
|
||||||
in all the nodes in a tree, which implements a dynamic programming algorithm
|
|
||||||
that exhausts the possible `num_participation` appearance of a sample in
|
|
||||||
`steps` leaf nodes.
|
|
||||||
|
|
||||||
See Appendix D of
|
|
||||||
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
|
||||||
https://arxiv.org/abs/2103.00039.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
noise_multiplier: A non-negative float representing the ratio of the
|
|
||||||
standard deviation of the Gaussian noise to the l2-sensitivity of the
|
|
||||||
function to which it is added.
|
|
||||||
total_steps: Total number of steps (leaf nodes in tree aggregation).
|
|
||||||
max_participation: The maximum number of times a sample can appear.
|
|
||||||
min_separation: The minimum number of nodes between two appearance of a
|
|
||||||
sample. If a sample appears in consecutive x, y steps in a streaming
|
|
||||||
setting, then `min_separation=y-x-1`.
|
|
||||||
orders: An array (or a scalar) of RDP orders.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The RDPs at all orders. Can be `np.inf`.
|
|
||||||
"""
|
|
||||||
_check_nonnegative(noise_multiplier, "noise_multiplier")
|
|
||||||
if noise_multiplier == 0:
|
|
||||||
return np.inf
|
|
||||||
_check_nonnegative(total_steps, "total_steps")
|
|
||||||
_check_nonnegative(max_participation, "max_participation")
|
|
||||||
_check_nonnegative(min_separation, "min_separation")
|
|
||||||
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
|
|
||||||
max_participation, min_separation, total_steps)
|
|
||||||
if np.isscalar(orders):
|
|
||||||
rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
|
|
||||||
orders)
|
|
||||||
else:
|
|
||||||
rdp = np.array([
|
|
||||||
_compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
|
|
||||||
for alpha in orders
|
|
||||||
])
|
|
||||||
return rdp
|
|
||||||
|
|
||||||
|
|
||||||
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
||||||
"""Compute RDP of Gaussian Mechanism using sampling without replacement.
|
"""Compute RDP of Gaussian Mechanism using sampling without replacement.
|
||||||
|
|
||||||
|
|
|
@ -264,161 +264,5 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertLessEqual(delta, delta1 + 1e-300)
|
self.assertLessEqual(delta, delta1 + 1e-300)
|
||||||
|
|
||||||
|
|
||||||
class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
|
|
||||||
|
|
||||||
@parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04))
|
|
||||||
def test_compute_eps_tree(self, noise_multiplier, eps):
|
|
||||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
|
||||||
# This tests is based on the StackOverflow setting in "Practical and
|
|
||||||
# Private (Deep) Learning without Sampling or Shuffling". The calculated
|
|
||||||
# epsilon could be better as the method in this package keeps improving.
|
|
||||||
steps_list, target_delta = 1600, 1e-6
|
|
||||||
rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps_list,
|
|
||||||
orders)
|
|
||||||
new_eps = rdp_accountant.get_privacy_spent(
|
|
||||||
orders, rdp, target_delta=target_delta)[0]
|
|
||||||
self.assertLess(new_eps, eps)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('restart4', [400] * 4),
|
|
||||||
('restart2', [800] * 2),
|
|
||||||
('adaptive', [10, 400, 400, 400, 390]),
|
|
||||||
)
|
|
||||||
def test_compose_tree_rdp(self, steps_list):
|
|
||||||
noise_multiplier, orders = 0.1, 1
|
|
||||||
rdp_list = [
|
|
||||||
rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps, orders)
|
|
||||||
for steps in steps_list
|
|
||||||
]
|
|
||||||
rdp_composed = rdp_accountant.compute_rdp_tree_restart(
|
|
||||||
noise_multiplier, steps_list, orders)
|
|
||||||
self.assertAllClose(rdp_composed, sum(rdp_list), rtol=1e-12)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('restart4', [400] * 4),
|
|
||||||
('restart2', [800] * 2),
|
|
||||||
('adaptive', [10, 400, 400, 400, 390]),
|
|
||||||
)
|
|
||||||
def test_compute_eps_tree_decreasing(self, steps_list):
|
|
||||||
# Test privacy epsilon decreases with noise multiplier increasing when
|
|
||||||
# keeping other parameters the same.
|
|
||||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
|
||||||
target_delta = 1e-6
|
|
||||||
prev_eps = rdp_accountant.compute_rdp_tree_restart(0, steps_list, orders)
|
|
||||||
for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
|
|
||||||
rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier,
|
|
||||||
steps_list, orders)
|
|
||||||
eps = rdp_accountant.get_privacy_spent(
|
|
||||||
orders, rdp, target_delta=target_delta)[0]
|
|
||||||
self.assertLess(eps, prev_eps)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('negative_noise', -1, 3, 1),
|
|
||||||
('empty_steps', 1, [], 1),
|
|
||||||
('negative_steps', 1, -3, 1),
|
|
||||||
)
|
|
||||||
def test_compute_rdp_tree_restart_raise(self, noise_multiplier, steps_list,
|
|
||||||
orders):
|
|
||||||
with self.assertRaisesRegex(ValueError, 'must be'):
|
|
||||||
rdp_accountant.compute_rdp_tree_restart(noise_multiplier, steps_list,
|
|
||||||
orders)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('t100n0.1', 100, 0.1),
|
|
||||||
('t1000n0.01', 1000, 0.01),
|
|
||||||
)
|
|
||||||
def test_no_tree_no_sampling(self, total_steps, noise_multiplier):
|
|
||||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
|
||||||
tree_rdp = rdp_accountant.compute_rdp_tree_restart(noise_multiplier,
|
|
||||||
[1] * total_steps,
|
|
||||||
orders)
|
|
||||||
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
|
|
||||||
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('negative_noise', -1, 3, 1, 1),
|
|
||||||
('negative_steps', 0.1, -3, 1, 1),
|
|
||||||
('negative_part', 0.1, 3, -1, 1),
|
|
||||||
('negative_sep', 0.1, 3, 1, -1),
|
|
||||||
)
|
|
||||||
def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
|
|
||||||
max_participation, min_separation):
|
|
||||||
orders = 1
|
|
||||||
with self.assertRaisesRegex(ValueError, 'must be'):
|
|
||||||
rdp_accountant.compute_rdp_single_tree(noise_multiplier, total_steps,
|
|
||||||
max_participation, min_separation,
|
|
||||||
orders)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('3', 3),
|
|
||||||
('8', 8),
|
|
||||||
('11', 11),
|
|
||||||
('19', 19),
|
|
||||||
)
|
|
||||||
def test_max_tree_sensitivity_square_sum_every_step(self, steps):
|
|
||||||
max_participation, min_separation = steps, 0
|
|
||||||
# If a sample will appear in every leaf node, we can infer the total
|
|
||||||
# sensitivity by adding all the nodes.
|
|
||||||
steps_bin = bin(steps)[2:]
|
|
||||||
depth = [
|
|
||||||
len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
|
|
||||||
]
|
|
||||||
expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
|
|
||||||
self.assertEqual(
|
|
||||||
expected,
|
|
||||||
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
|
|
||||||
min_separation, steps))
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('11', 11),
|
|
||||||
('19', 19),
|
|
||||||
('200', 200),
|
|
||||||
)
|
|
||||||
def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
|
|
||||||
steps, min_separation = 8, 0
|
|
||||||
assert max_part > steps
|
|
||||||
# If a sample will appear in every leaf node, we can infer the total
|
|
||||||
# sensitivity by adding all the nodes.
|
|
||||||
expected = 120
|
|
||||||
self.assertEqual(
|
|
||||||
expected,
|
|
||||||
rdp_accountant._max_tree_sensitivity_square_sum(max_part,
|
|
||||||
min_separation, steps))
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('3', 3),
|
|
||||||
('8', 8),
|
|
||||||
('11', 11),
|
|
||||||
('19', 19),
|
|
||||||
)
|
|
||||||
def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
|
|
||||||
max_participation, min_separation = 2, 0
|
|
||||||
# If a sample will appear twice, the worst case is to put the two nodes at
|
|
||||||
# consecutive nodes of the deepest subtree.
|
|
||||||
steps_bin = bin(steps)[2:]
|
|
||||||
depth = len(steps_bin) - 1
|
|
||||||
expected = 2 + 4 * depth
|
|
||||||
self.assertEqual(
|
|
||||||
expected,
|
|
||||||
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
|
|
||||||
min_separation, steps))
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
('test1', 1, 7, 8, 4),
|
|
||||||
('test2', 3, 3, 9, 11),
|
|
||||||
('test3', 3, 2, 7, 9),
|
|
||||||
# This is an example showing worst-case sensitivity is larger than greedy
|
|
||||||
# in "Practical and Private (Deep) Learning without Sampling or Shuffling"
|
|
||||||
# https://arxiv.org/abs/2103.00039.
|
|
||||||
('test4', 8, 2, 24, 88),
|
|
||||||
)
|
|
||||||
def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
|
|
||||||
min_separation, steps, expected):
|
|
||||||
self.assertEqual(
|
|
||||||
expected,
|
|
||||||
rdp_accountant._max_tree_sensitivity_square_sum(max_participation,
|
|
||||||
min_separation, steps))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -0,0 +1,315 @@
|
||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""DP analysis of tree aggregation.
|
||||||
|
|
||||||
|
See Appendix D of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Functionality for computing differential privacy of tree aggregation of Gaussian
|
||||||
|
mechanism. Its public interface consists of the following methods:
|
||||||
|
compute_rdp_tree_restart(
|
||||||
|
noise_multiplier: float, steps_list: Union[int, Collection[int]],
|
||||||
|
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
|
||||||
|
computes RDP for DP-FTRL-TreeRestart.
|
||||||
|
compute_rdp_single_tree(
|
||||||
|
noise_multiplier: float, total_steps: int, max_participation: int,
|
||||||
|
min_separation: int,
|
||||||
|
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
|
||||||
|
computes RDP for DP-FTRL-NoTreeRestart.
|
||||||
|
|
||||||
|
For RDP to (epsilon, delta)-DP conversion, use the following public function
|
||||||
|
described in `rdp_accountant.py`:
|
||||||
|
get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
|
||||||
|
(or eps) given RDP at multiple orders and
|
||||||
|
a target value for eps (or delta).
|
||||||
|
|
||||||
|
Example use:
|
||||||
|
|
||||||
|
(1) DP-FTRL-TreeRestart RDP:
|
||||||
|
Suppose we use Gaussian mechanism of `noise_multiplier`; a sample may appear
|
||||||
|
at most once for every epoch and tree is restarted every epoch; the number of
|
||||||
|
leaf nodes for every epoch are tracked in `steps_list`. For `target_delta`, the
|
||||||
|
estimated epsilon is:
|
||||||
|
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||||
|
rdp = compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
|
||||||
|
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
|
||||||
|
|
||||||
|
(2) DP-FTRL-NoTreeRestart RDP:
|
||||||
|
Suppose we use Gaussian mechanism of `noise_multiplier`; a sample may appear
|
||||||
|
at most `max_participation` times for a total of `total_steps` leaf nodes in a
|
||||||
|
single tree; there are at least `min_separation` leaf nodes between the two
|
||||||
|
appearance of a same sample. For `target_delta`, the estimated epsilon is:
|
||||||
|
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||||
|
rdp = compute_rdp_single_tree(noise_multiplier, total_steps,
|
||||||
|
max_participation, min_separation, orders)
|
||||||
|
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import math
|
||||||
|
from typing import Collection, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
|
||||||
|
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
|
||||||
|
if np.isinf(alpha):
|
||||||
|
return np.inf
|
||||||
|
tree_depths = [
|
||||||
|
math.floor(math.log2(float(steps))) + 1
|
||||||
|
for steps in steps_list
|
||||||
|
if steps > 0
|
||||||
|
]
|
||||||
|
return _compute_gaussian_rdp(
|
||||||
|
alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_rdp_tree_restart(
|
||||||
|
noise_multiplier: float, steps_list: Union[int, Collection[int]],
|
||||||
|
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
|
||||||
|
"""Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism.
|
||||||
|
|
||||||
|
This function implements the accounting when the tree is restarted at every
|
||||||
|
epoch. See appendix D of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noise_multiplier: A non-negative float representing the ratio of the
|
||||||
|
standard deviation of the Gaussian noise to the l2-sensitivity of the
|
||||||
|
function to which it is added.
|
||||||
|
steps_list: A scalar or a list of non-negative intergers representing the
|
||||||
|
number of steps per epoch (between two restarts).
|
||||||
|
orders: An array (or a scalar) of RDP orders.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The RDPs at all orders. Can be `np.inf`.
|
||||||
|
"""
|
||||||
|
_check_nonnegative(noise_multiplier, "noise_multiplier")
|
||||||
|
if noise_multiplier == 0:
|
||||||
|
return np.inf
|
||||||
|
|
||||||
|
if not steps_list:
|
||||||
|
raise ValueError(
|
||||||
|
"steps_list must be a non-empty list, or a non-zero scalar, got "
|
||||||
|
f"{steps_list}.")
|
||||||
|
|
||||||
|
if np.isscalar(steps_list):
|
||||||
|
steps_list = [steps_list]
|
||||||
|
|
||||||
|
for steps in steps_list:
|
||||||
|
if steps < 0:
|
||||||
|
raise ValueError(f"Steps must be non-negative, got {steps_list}")
|
||||||
|
|
||||||
|
if np.isscalar(orders):
|
||||||
|
rdp = _compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
|
||||||
|
else:
|
||||||
|
rdp = np.array([
|
||||||
|
_compute_rdp_tree_restart(noise_multiplier, steps_list, alpha)
|
||||||
|
for alpha in orders
|
||||||
|
])
|
||||||
|
|
||||||
|
return rdp
|
||||||
|
|
||||||
|
|
||||||
|
def _check_nonnegative(value: Union[int, float], name: str):
|
||||||
|
if value < 0:
|
||||||
|
raise ValueError(f"Provided {name} must be non-negative, got {value}")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_possible_tree_participation(num_participation: int,
|
||||||
|
min_separation: int, start: int,
|
||||||
|
end: int, steps: int) -> bool:
|
||||||
|
"""Check if participation is possible with `min_separation` in `steps`.
|
||||||
|
|
||||||
|
This function checks if it is possible for a sample to appear
|
||||||
|
`num_participation` in `steps`, assuming there are at least `min_separation`
|
||||||
|
nodes between the appearance of the same sample in the streaming data (leaf
|
||||||
|
nodes in tree aggregation). The first appearance of the sample is after
|
||||||
|
`start` steps, and the sample won't appear in the `end` steps after the given
|
||||||
|
`steps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_participation: The number of times a sample will appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y steps in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
start: The first appearance of the sample is after `start` steps.
|
||||||
|
end: The sample won't appear in the `end` steps after the given `steps`.
|
||||||
|
steps: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a sample can appear `num_participation` with given conditions.
|
||||||
|
"""
|
||||||
|
return start + (min_separation + 1) * num_participation <= steps + end
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
|
||||||
|
start: int, end: int, size: int) -> float:
|
||||||
|
"""Compute the worst-case sum of sensitivtiy square for `num_participation`.
|
||||||
|
|
||||||
|
This is the key algorithm for DP accounting for DP-FTRL tree aggregation
|
||||||
|
without restart, which recurrently counts the worst-case occurence of a sample
|
||||||
|
in all the nodes in a tree. This implements a dynamic programming algorithm
|
||||||
|
that exhausts the possible `num_participation` appearance of a sample in
|
||||||
|
`size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_participation: The number of times a sample will appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y size in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
start: The first appearance of the sample is after `start` steps.
|
||||||
|
end: The sample won't appear in the `end` steps after given `size` steps.
|
||||||
|
size: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The worst-case sum of sensitivity square for the given input.
|
||||||
|
"""
|
||||||
|
if not _check_possible_tree_participation(num_participation, min_separation,
|
||||||
|
start, end, size):
|
||||||
|
sum_value = -np.inf
|
||||||
|
elif num_participation == 0:
|
||||||
|
sum_value = 0.
|
||||||
|
elif num_participation == 1 and size == 1:
|
||||||
|
sum_value = 1.
|
||||||
|
else:
|
||||||
|
size_log2 = math.log2(size)
|
||||||
|
max_2power = math.floor(size_log2)
|
||||||
|
if max_2power == size_log2:
|
||||||
|
sum_value = num_participation**2
|
||||||
|
max_2power -= 1
|
||||||
|
else:
|
||||||
|
sum_value = 0.
|
||||||
|
candidate_sum = []
|
||||||
|
# i is the `num_participation` in the right subtree
|
||||||
|
for i in range(num_participation + 1):
|
||||||
|
# j is the `start` in the right subtree
|
||||||
|
for j in range(min_separation + 1):
|
||||||
|
left_sum = _tree_sensitivity_square_sum(
|
||||||
|
num_participation=num_participation - i,
|
||||||
|
min_separation=min_separation,
|
||||||
|
start=start,
|
||||||
|
end=j,
|
||||||
|
size=2**max_2power)
|
||||||
|
if np.isinf(left_sum):
|
||||||
|
candidate_sum.append(-np.inf)
|
||||||
|
continue # Early pruning for dynamic programming
|
||||||
|
right_sum = _tree_sensitivity_square_sum(
|
||||||
|
num_participation=i,
|
||||||
|
min_separation=min_separation,
|
||||||
|
start=j,
|
||||||
|
end=end,
|
||||||
|
size=size - 2**max_2power)
|
||||||
|
candidate_sum.append(left_sum + right_sum)
|
||||||
|
sum_value += max(candidate_sum)
|
||||||
|
return sum_value
|
||||||
|
|
||||||
|
|
||||||
|
def _max_tree_sensitivity_square_sum(max_participation: int,
|
||||||
|
min_separation: int, steps: int) -> float:
|
||||||
|
"""Compute the worst-case sum of sensitivity square in tree aggregation.
|
||||||
|
|
||||||
|
See Appendix D.2 of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_participation: The maximum number of times a sample will appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y steps in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
steps: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The worst-case sum of sensitivity square for the given input.
|
||||||
|
"""
|
||||||
|
num_participation = max_participation
|
||||||
|
while not _check_possible_tree_participation(
|
||||||
|
num_participation, min_separation, 0, min_separation, steps):
|
||||||
|
num_participation -= 1
|
||||||
|
candidate_sum = []
|
||||||
|
for num_part in range(1, num_participation + 1):
|
||||||
|
candidate_sum.append(
|
||||||
|
_tree_sensitivity_square_sum(num_part, min_separation, 0,
|
||||||
|
min_separation, steps))
|
||||||
|
return max(candidate_sum)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
|
||||||
|
alpha: float) -> float:
|
||||||
|
"""Computes RDP of Gaussian mechanism."""
|
||||||
|
if np.isinf(alpha):
|
||||||
|
return np.inf
|
||||||
|
return alpha * sum_sensitivity_square / (2 * sigma**2)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_rdp_single_tree(
|
||||||
|
noise_multiplier: float, total_steps: int, max_participation: int,
|
||||||
|
min_separation: int,
|
||||||
|
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
|
||||||
|
"""Computes RDP of the Tree Aggregation Protocol for a single tree.
|
||||||
|
|
||||||
|
The accounting assume a single tree is constructed for `total_steps` leaf
|
||||||
|
nodes, where the same sample will appear at most `max_participation` times,
|
||||||
|
and there are at least `min_separation` nodes between two appearance. The key
|
||||||
|
idea is to (recurrently) count the worst-case occurence of a sample
|
||||||
|
in all the nodes in a tree, which implements a dynamic programming algorithm
|
||||||
|
that exhausts the possible `num_participation` appearance of a sample in
|
||||||
|
`steps` leaf nodes.
|
||||||
|
|
||||||
|
See Appendix D of
|
||||||
|
"Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
https://arxiv.org/abs/2103.00039.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noise_multiplier: A non-negative float representing the ratio of the
|
||||||
|
standard deviation of the Gaussian noise to the l2-sensitivity of the
|
||||||
|
function to which it is added.
|
||||||
|
total_steps: Total number of steps (leaf nodes in tree aggregation).
|
||||||
|
max_participation: The maximum number of times a sample can appear.
|
||||||
|
min_separation: The minimum number of nodes between two appearance of a
|
||||||
|
sample. If a sample appears in consecutive x, y steps in a streaming
|
||||||
|
setting, then `min_separation=y-x-1`.
|
||||||
|
orders: An array (or a scalar) of RDP orders.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The RDPs at all orders. Can be `np.inf`.
|
||||||
|
"""
|
||||||
|
_check_nonnegative(noise_multiplier, "noise_multiplier")
|
||||||
|
if noise_multiplier == 0:
|
||||||
|
return np.inf
|
||||||
|
_check_nonnegative(total_steps, "total_steps")
|
||||||
|
_check_nonnegative(max_participation, "max_participation")
|
||||||
|
_check_nonnegative(min_separation, "min_separation")
|
||||||
|
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
|
||||||
|
max_participation, min_separation, total_steps)
|
||||||
|
if np.isscalar(orders):
|
||||||
|
rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
|
||||||
|
orders)
|
||||||
|
else:
|
||||||
|
rdp = np.array([
|
||||||
|
_compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
|
||||||
|
for alpha in orders
|
||||||
|
])
|
||||||
|
return rdp
|
|
@ -0,0 +1,185 @@
|
||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for rdp_accountant.py."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis import rdp_accountant
|
||||||
|
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant
|
||||||
|
|
||||||
|
|
||||||
|
class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04))
|
||||||
|
def test_compute_eps_tree(self, noise_multiplier, eps):
|
||||||
|
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||||
|
# This tests is based on the StackOverflow setting in "Practical and
|
||||||
|
# Private (Deep) Learning without Sampling or Shuffling". The calculated
|
||||||
|
# epsilon could be better as the method in this package keeps improving.
|
||||||
|
steps_list, target_delta = 1600, 1e-6
|
||||||
|
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
|
||||||
|
noise_multiplier, steps_list, orders)
|
||||||
|
new_eps = rdp_accountant.get_privacy_spent(
|
||||||
|
orders, rdp, target_delta=target_delta)[0]
|
||||||
|
self.assertLess(new_eps, eps)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('restart4', [400] * 4),
|
||||||
|
('restart2', [800] * 2),
|
||||||
|
('adaptive', [10, 400, 400, 400, 390]),
|
||||||
|
)
|
||||||
|
def test_compose_tree_rdp(self, steps_list):
|
||||||
|
noise_multiplier, orders = 0.1, 1
|
||||||
|
rdp_list = [
|
||||||
|
tree_aggregation_accountant.compute_rdp_tree_restart(
|
||||||
|
noise_multiplier, steps, orders) for steps in steps_list
|
||||||
|
]
|
||||||
|
rdp_composed = tree_aggregation_accountant.compute_rdp_tree_restart(
|
||||||
|
noise_multiplier, steps_list, orders)
|
||||||
|
self.assertAllClose(rdp_composed, sum(rdp_list), rtol=1e-12)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('restart4', [400] * 4),
|
||||||
|
('restart2', [800] * 2),
|
||||||
|
('adaptive', [10, 400, 400, 400, 390]),
|
||||||
|
)
|
||||||
|
def test_compute_eps_tree_decreasing(self, steps_list):
|
||||||
|
# Test privacy epsilon decreases with noise multiplier increasing when
|
||||||
|
# keeping other parameters the same.
|
||||||
|
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||||
|
target_delta = 1e-6
|
||||||
|
prev_eps = tree_aggregation_accountant.compute_rdp_tree_restart(
|
||||||
|
0, steps_list, orders)
|
||||||
|
for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
|
||||||
|
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
|
||||||
|
noise_multiplier, steps_list, orders)
|
||||||
|
eps = rdp_accountant.get_privacy_spent(
|
||||||
|
orders, rdp, target_delta=target_delta)[0]
|
||||||
|
self.assertLess(eps, prev_eps)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('negative_noise', -1, 3, 1),
|
||||||
|
('empty_steps', 1, [], 1),
|
||||||
|
('negative_steps', 1, -3, 1),
|
||||||
|
)
|
||||||
|
def test_compute_rdp_tree_restart_raise(self, noise_multiplier, steps_list,
|
||||||
|
orders):
|
||||||
|
with self.assertRaisesRegex(ValueError, 'must be'):
|
||||||
|
tree_aggregation_accountant.compute_rdp_tree_restart(
|
||||||
|
noise_multiplier, steps_list, orders)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('t100n0.1', 100, 0.1),
|
||||||
|
('t1000n0.01', 1000, 0.01),
|
||||||
|
)
|
||||||
|
def test_no_tree_no_sampling(self, total_steps, noise_multiplier):
|
||||||
|
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||||
|
tree_rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
|
||||||
|
noise_multiplier, [1] * total_steps, orders)
|
||||||
|
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
|
||||||
|
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('negative_noise', -1, 3, 1, 1),
|
||||||
|
('negative_steps', 0.1, -3, 1, 1),
|
||||||
|
('negative_part', 0.1, 3, -1, 1),
|
||||||
|
('negative_sep', 0.1, 3, 1, -1),
|
||||||
|
)
|
||||||
|
def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
|
||||||
|
max_participation, min_separation):
|
||||||
|
orders = 1
|
||||||
|
with self.assertRaisesRegex(ValueError, 'must be'):
|
||||||
|
tree_aggregation_accountant.compute_rdp_single_tree(
|
||||||
|
noise_multiplier, total_steps, max_participation, min_separation,
|
||||||
|
orders)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('3', 3),
|
||||||
|
('8', 8),
|
||||||
|
('11', 11),
|
||||||
|
('19', 19),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_every_step(self, steps):
|
||||||
|
max_participation, min_separation = steps, 0
|
||||||
|
# If a sample will appear in every leaf node, we can infer the total
|
||||||
|
# sensitivity by adding all the nodes.
|
||||||
|
steps_bin = bin(steps)[2:]
|
||||||
|
depth = [
|
||||||
|
len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
|
||||||
|
]
|
||||||
|
expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
|
||||||
|
max_participation, min_separation, steps))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('11', 11),
|
||||||
|
('19', 19),
|
||||||
|
('200', 200),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
|
||||||
|
steps, min_separation = 8, 0
|
||||||
|
assert max_part > steps
|
||||||
|
# If a sample will appear in every leaf node, we can infer the total
|
||||||
|
# sensitivity by adding all the nodes.
|
||||||
|
expected = 120
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
|
||||||
|
max_part, min_separation, steps))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('3', 3),
|
||||||
|
('8', 8),
|
||||||
|
('11', 11),
|
||||||
|
('19', 19),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
|
||||||
|
max_participation, min_separation = 2, 0
|
||||||
|
# If a sample will appear twice, the worst case is to put the two nodes at
|
||||||
|
# consecutive nodes of the deepest subtree.
|
||||||
|
steps_bin = bin(steps)[2:]
|
||||||
|
depth = len(steps_bin) - 1
|
||||||
|
expected = 2 + 4 * depth
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
|
||||||
|
max_participation, min_separation, steps))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('test1', 1, 7, 8, 4),
|
||||||
|
('test2', 3, 3, 9, 11),
|
||||||
|
('test3', 3, 2, 7, 9),
|
||||||
|
# This is an example showing worst-case sensitivity is larger than greedy
|
||||||
|
# in "Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
# https://arxiv.org/abs/2103.00039.
|
||||||
|
('test4', 8, 2, 24, 88),
|
||||||
|
)
|
||||||
|
def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
|
||||||
|
min_separation, steps, expected):
|
||||||
|
self.assertEqual(
|
||||||
|
expected,
|
||||||
|
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
|
||||||
|
max_participation, min_separation, steps))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
Loading…
Reference in a new issue