COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/privacy/pull/230 from npapernot:hyperparam 8835b9c4072e3e598aa49d605e7643a2c2e65988
PiperOrigin-RevId: 446832781
This commit is contained in:
parent
930c4d13e8
commit
7eea74a6a1
9 changed files with 935 additions and 17 deletions
research/hyperparameters_2022/README.md (new file, 30 lines)

@@ -0,0 +1,30 @@
# Hyperparameter Tuning with Renyi Differential Privacy

### Nicolas Papernot and Thomas Steinke

This repository contains the code used to reproduce some of the experiments in
our
[ICLR 2022 paper on hyperparameter tuning with differential privacy](https://openreview.net/forum?id=-70L8lpp9DF).

You can reproduce Figure 7 in the paper by running `figure7.py`. By default, it
loads the values used to plot the figure contained in the paper; we also
include a dictionary `lr_acc.json` containing the accuracy of a large number of
ML models trained with different learning rates. If you'd like to try our
approach to fine-tune your own parameters, you will have to modify the code
that interacts with this dictionary (`lr_acc` in the code from `figure7.py`).
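If you do, note that `lr_acc.json` maps each candidate learning rate (stored
as a string key, since the code looks up `lr_acc[str(lr_candidate)]`) to the
accuracy of the model trained with that rate. A minimal sketch of the expected
structure, with made-up learning rates and accuracies:

```
# Illustrative only: keys are stringified learning rates, values are
# the accuracy obtained by a model trained with that learning rate.
lr_acc = {
    "0.05": 0.91,
    "0.1": 0.95,
    "0.2": 0.93,
}
```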
## Citing this work

If you use this repository for academic research, you are highly encouraged
(though not required) to cite our paper:

```
@inproceedings{
papernot2022hyperparameter,
title={Hyperparameter Tuning with Renyi Differential Privacy},
author={Nicolas Papernot and Thomas Steinke},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=-70L8lpp9DF}
}
```
research/hyperparameters_2022/figure7.py (new file, 199 lines)

@@ -0,0 +1,199 @@
# Copyright 2022, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code for reproducing Figure 7 of the paper."""

import math
import matplotlib.pyplot as plt
import numpy as np
import rdp_accountant

# pylint: disable=bare-except
# pylint: disable=g-import-not-at-top
# pylint: disable=g-multiple-import
# pylint: disable=missing-function-docstring
# pylint: disable=redefined-outer-name

####################################################
# This file loads default values to reproduce
# figure 7 from the paper. If you'd like to
# provide your own values, modify the variables
# in the else branch of the if statement controlled
# by this variable.
####################################################

load_values_to_reproduce_paper_fig = True


def repeat_logarithmic_rdp(orders, rdp, gamma):
  """RDP of the search when the number of runs K is logarithmically distributed."""
  n = len(orders)
  assert len(rdp) == n
  assert min(orders) >= 1
  rdp_out = [None] * n
  for i in range(n):
    if orders[i] == 1:
      continue  # unfortunately the formula doesn't work in this case
    for j in range(n):
      # Compute (orders[i],eps)-RDP bound on A_gamma given that Q satisfies
      # (orders[i],rdp[i])-RDP and (orders[j],rdp[j])-RDP
      eps = rdp[i] + (
          1 - 1 / orders[j]) * rdp[j] + math.log(1 / gamma - 1) / orders[j] + (
              math.log(1 / gamma - 1) - math.log(math.log(1 / gamma))) / (
                  orders[i] - 1)
      if rdp_out[i] is None or eps < rdp_out[i]:
        rdp_out[i] = eps
  return rdp_out


def repeat_geometric_rdp(orders, rdp, gamma):
  """RDP of the search when the number of runs K is geometrically distributed."""
  n = len(orders)
  assert len(rdp) == n
  assert min(orders) >= 1
  rdp_out = [None] * n
  for i in range(n):
    if orders[i] == 1:
      continue  # formula doesn't work in this case
    for j in range(n):
      eps = rdp[i] + 2 * (1 - 1 / orders[j]) * rdp[j] + (
          2 / orders[j] + 1 / (orders[i] - 1)) * math.log(1 / gamma)
      if rdp_out[i] is None or eps < rdp_out[i]:
        rdp_out[i] = eps
  return rdp_out


def repeat_negativebinomial_rdp(orders, rdp, gamma, eta):
  """RDP of the search when K follows a truncated negative binomial distribution."""
  n = len(orders)
  assert len(rdp) == n
  assert min(orders) >= 1
  assert 0 < gamma < 1
  assert eta > 0
  rdp_out = [None] * n
  # foo = log(eta/(1-gamma^eta))
  foo = math.log(eta) - math.log1p(-math.pow(gamma, eta))
  for i in range(n):
    if orders[i] == 1:
      continue  # formula doesn't work for lambda=1
    for j in range(n):
      eps = rdp[i] + (1 + eta) * (1 - 1 / orders[j]) * rdp[j] - (
          (1 + eta) / orders[j] + 1 /
          (orders[i] - 1)) * math.log(gamma) + foo / (orders[i] - 1) + (
              1 + eta) * math.log1p(-gamma) / orders[j]
      if rdp_out[i] is None or eps < rdp_out[i]:
        rdp_out[i] = eps
  return rdp_out


def repeat_poisson_rdp(orders, rdp, tau):
  """RDP of the search when the number of runs K is Poisson(tau) distributed."""
  n = len(orders)
  assert len(rdp) == n
  assert min(orders) >= 1
  rdp_out = [None] * n
  for i in range(n):
    if orders[i] == 1:
      continue  # formula doesn't work with lambda=1
    _, delta, _ = rdp_accountant.get_privacy_spent(
        orders, rdp, target_eps=math.log1p(1 / (orders[i] - 1)))
    rdp_out[i] = rdp[i] + tau * delta + math.log(tau) / (orders[i] - 1)
  return rdp_out


if load_values_to_reproduce_paper_fig:
  from figure7_default_values import orders, rdp, lr_acc, num_trials, lr_rates, gammas, non_private_acc
else:
  orders = []  # Complete with the list of orders
  rdp = []  # Complete with the list of RDP values
  lr_acc = {}  # Complete with a dictionary whose keys
  # are learning rates and whose values are the
  # corresponding model's accuracy
  num_trials = 1000  # number of trials to average results over
  lr_rates = np.asarray([])  # 1D array of learning rate candidates
  gammas = np.asarray(
      [])  # 1D array of gamma parameters to the random distributions.
  non_private_acc = 1.  # accuracy of a non-private run (for plotting only)

for dist_id in range(4):
  res_x = np.zeros_like(gammas)
  res_y = np.zeros_like(res_x)
  res_y_max = non_private_acc * np.ones_like(res_x)
  for gamma_id, gamma in enumerate(gammas):
    # Expected number of runs K under the logarithmic distribution.
    expected = (1 / gamma - 1) / np.log(1 / gamma)
    best_acc_trials = []
    for trial in range(num_trials):
      if dist_id == 0:
        K = np.random.logseries(1 - gamma)
        label = 'logarithmic distribution $\\eta=0$'
        color = 'b'
        eps = repeat_logarithmic_rdp(orders, rdp, gamma)
      elif dist_id == 1:
        if load_values_to_reproduce_paper_fig and gamma < 1e-4:
          continue
        K = np.random.geometric(gamma)
        label = 'geometric distribution $\\eta=1$'
        color = 'g'
        eps = repeat_geometric_rdp(orders, rdp, gamma)
      elif dist_id == 2:
        if load_values_to_reproduce_paper_fig and gamma < 1e-07:
          continue
        eta = 0.5
        K = 0
        while K == 0:
          K = np.random.negative_binomial(eta, gamma)
        label = 'negative binomial $\\eta=0.5$'
        color = 'k'
        eps = repeat_negativebinomial_rdp(orders, rdp, gamma, eta)
      elif dist_id == 3:
        if load_values_to_reproduce_paper_fig and gamma < 0.0015:
          continue
        gamma_factor = 100
        K = np.random.poisson(gamma * gamma_factor)
        label = 'poisson distribution'
        color = 'm'
        eps = repeat_poisson_rdp(orders, rdp, gamma * gamma_factor)
      best_acc = 0.
      best_lr = -1.
      for k in range(K):
        # pick a hyperparam candidate uniformly at random
        j = np.random.randint(0, len(lr_rates))
        lr_candidate = lr_rates[j]
        try:
          acc = lr_acc[str(lr_candidate)]
        except:
          print('lr - acc pair missing for ' + str(lr_candidate))
          acc = 0.
        if best_acc < acc:
          best_acc = acc
          best_lr = lr_candidate
      best_acc_trials.append(best_acc)
    try:
      res_x[gamma_id] = np.min(eps)
      res_y[gamma_id] = np.mean(best_acc_trials)
    except:
      print('skipping ' + str(gamma_id))
  if dist_id == 0:
    plt.hlines(
        res_y_max[0],
        xmin=-1.,
        xmax=20.,
        color='r',
        label='baseline (non-private search)')
  if dist_id >= 1:
    res_x = res_x[2:]
    res_y = res_y[2:]
  plt.plot(res_x, res_y, label=label, color=color)

if load_values_to_reproduce_paper_fig:
  plt.xlim([0.5, 8.])
  plt.ylim([0.85, 0.97])
plt.xlabel('Privacy budget')
plt.ylabel('Model Accuracy for Best Hyperparameter')
plt.legend(loc='lower right')
plt.savefig('rdp_hyper_search.pdf', bbox_inches='tight')
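A minimal sketch of how the pieces above fit together outside of the plotting
loop, assuming the `repeat_geometric_rdp` helper from `figure7.py` is in scope
(all numbers below are illustrative, not values from the paper):

```
import rdp_accountant

orders = [2., 4., 8., 16., 32.]
base_rdp = [0.1, 0.25, 0.6, 1.5, 4.0]  # illustrative RDP of one training run
gamma = 0.05  # parameter of the geometric distribution for the number of runs

# RDP of the entire hyperparameter search when the number of candidate
# runs K is drawn from a geometric distribution with parameter gamma.
# repeat_geometric_rdp is assumed in scope (defined in figure7.py above).
search_rdp = repeat_geometric_rdp(orders, base_rdp, gamma)

# Convert the search's RDP curve into an (eps, delta) guarantee.
eps, _, opt_order = rdp_accountant.get_privacy_spent(
    orders, search_rdp, target_delta=1e-5)
print('search satisfies ({:.2f}, 1e-5)-DP at order {}'.format(eps, opt_order))
```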
research/hyperparameters_2022/figure7_default_values.py (new file, 47 lines)

@@ -0,0 +1,47 @@
# Copyright 2022, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Default values for generating Figure 7."""

import json
import numpy as np

orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
          list(range(5, 64)) + [128, 256, 512])
rdp = [
    2.04459751e-01, 2.45818210e-01, 2.87335988e-01, 3.29014798e-01,
    3.70856385e-01, 4.12862542e-01, 4.97375951e-01, 5.82570265e-01,
    6.68461534e-01, 7.55066706e-01, 8.42403732e-01, 1.01935100e+00,
    1.19947313e+00, 1.38297035e+00, 1.57009549e+00, 1.76124790e+00,
    1.95794503e+00, 2.19017390e+00, 4.48407479e+00, 3.08305394e+02,
    4.98610133e+03, 1.11363692e+04, 1.72590079e+04, 2.33487231e+04,
    2.94091123e+04, 3.54439803e+04, 4.14567914e+04, 4.74505356e+04,
    5.34277419e+04, 5.93905358e+04, 6.53407051e+04, 7.12797586e+04,
    7.72089762e+04, 8.31294496e+04, 8.90421151e+04, 9.49477802e+04,
    1.00847145e+05, 1.06740819e+05, 1.12629335e+05, 1.18513163e+05,
    1.24392717e+05, 1.30268362e+05, 1.36140424e+05, 1.42009194e+05,
    1.47874932e+05, 1.53737871e+05, 1.59598221e+05, 1.65456171e+05,
    1.71311893e+05, 1.77165542e+05, 1.83017260e+05, 1.88867175e+05,
    1.94715404e+05, 2.00562057e+05, 2.06407230e+05, 2.12251015e+05,
    2.18093495e+05, 2.23934746e+05, 2.29774840e+05, 2.35613842e+05,
    2.41451813e+05, 2.47288808e+05, 2.53124881e+05, 2.58960080e+05,
    2.64794449e+05, 2.70628032e+05, 2.76460867e+05, 2.82292992e+05,
    2.88124440e+05, 6.66483142e+05, 1.41061455e+06, 2.89842152e+06
]
with open("lr_acc.json", "r") as dict_f:
  lr_acc = json.load(dict_f)
num_trials = 1000
lr_rates = np.logspace(np.log10(1e-4), np.log10(1.), num=1000)[-400:]
gammas = np.asarray(
    [1e-07, 8e-06, 1e-04, 0.00024, 0.0015, 0.0035, 0.025, 0.05, 0.1, 0.2, 0.5])
non_private_acc = 0.9594
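The `rdp` list above is the precomputed Renyi DP of a single training run at
each of the `orders`. A sketch of how such a curve could be regenerated with
the bundled accountant, under hypothetical DP-SGD parameters (the sampling
rate, noise multiplier, and step count below are illustrative assumptions,
not the values behind the paper's numbers):

```
import rdp_accountant

orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
          list(range(5, 64)) + [128, 256, 512])
# RDP of `steps` iterations of DP-SGD with Poisson sampling rate q and
# noise multiplier sigma, evaluated at every order.
rdp = rdp_accountant.compute_rdp(q=0.01, noise_multiplier=1.1, steps=1000,
                                 orders=orders)
```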
research/hyperparameters_2022/lr_acc.json (new file, 14 lines)

File diff suppressed because one or more lines are too long
research/hyperparameters_2022/rdp_accountant.py (new file, 622 lines)

@@ -0,0 +1,622 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RDP analysis of the Sampled Gaussian Mechanism.

Functionality for computing Renyi differential privacy (RDP) of an additive
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
  compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated
                                              T times.
  get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
                                              (or eps) given RDP at multiple
                                              orders and a target value for
                                              eps (or delta).

Example use:

Suppose that we have run an SGM applied to a function with l2-sensitivity 1.
Its parameters are given as a list of tuples (q1, sigma1, T1), ...,
(qk, sigma_k, Tk), and we wish to compute eps for a given delta.
The example code would be:

  max_order = 32
  orders = range(2, max_order + 1)
  rdp = np.zeros_like(orders, dtype=float)
  for q, sigma, T in parameters:
    rdp += rdp_accountant.compute_rdp(q, sigma, T, orders)
  eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
                                                       target_delta=delta)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import sys

import numpy as np
from scipy import special
import six

########################
# LOG-SPACE ARITHMETIC #
########################


def _log_add(logx, logy):
  """Add two numbers in the log space."""
  a, b = min(logx, logy), max(logx, logy)
  if a == -np.inf:  # adding 0
    return b
  # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
  return math.log1p(math.exp(a - b)) + b  # log1p(x) = log(x + 1)


def _log_sub(logx, logy):
  """Subtract two numbers in the log space. Answer must be non-negative."""
  if logx < logy:
    raise ValueError("The result of subtraction must be non-negative.")
  if logy == -np.inf:  # subtracting 0
    return logx
  if logx == logy:
    return -np.inf  # 0 is represented as -np.inf in the log space.

  try:
    # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
    return math.log(math.expm1(logx - logy)) + logy  # expm1(x) = exp(x) - 1
  except OverflowError:
    return logx


def _log_sub_sign(logx, logy):
  """Returns log(exp(logx)-exp(logy)) and its sign."""
  if logx > logy:
    s = True
    mag = logx + np.log(1 - np.exp(logy - logx))
  elif logx < logy:
    s = False
    mag = logy + np.log(1 - np.exp(logx - logy))
  else:
    s = True
    mag = -np.inf

  return s, mag


def _log_print(logx):
  """Pretty print."""
  if logx < math.log(sys.float_info.max):
    return "{}".format(math.exp(logx))
  else:
    return "exp({})".format(logx)


def _log_comb(n, k):
  return (special.gammaln(n + 1) - special.gammaln(k + 1) -
          special.gammaln(n - k + 1))


def _compute_log_a_int(q, sigma, alpha):
  """Compute log(A_alpha) for integer alpha. 0 < q < 1."""
  assert isinstance(alpha, six.integer_types)

  # Initialize with 0 in the log space.
  log_a = -np.inf

  for i in range(alpha + 1):
    log_coef_i = (
        _log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q))

    s = log_coef_i + (i * i - i) / (2 * (sigma**2))
    log_a = _log_add(log_a, s)

  return float(log_a)


def _compute_log_a_frac(q, sigma, alpha):
  """Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
  # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
  # initialized to 0 in the log space:
  log_a0, log_a1 = -np.inf, -np.inf
  i = 0

  z0 = sigma**2 * math.log(1 / q - 1) + .5

  while True:  # do ... until loop
    coef = special.binom(alpha, i)
    log_coef = math.log(abs(coef))
    j = alpha - i

    log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
    log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)

    log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
    log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))

    log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
    log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1

    if coef > 0:
      log_a0 = _log_add(log_a0, log_s0)
      log_a1 = _log_add(log_a1, log_s1)
    else:
      log_a0 = _log_sub(log_a0, log_s0)
      log_a1 = _log_sub(log_a1, log_s1)

    i += 1
    if max(log_s0, log_s1) < -30:
      break

  return _log_add(log_a0, log_a1)


def _compute_log_a(q, sigma, alpha):
  """Compute log(A_alpha) for any positive finite alpha."""
  if float(alpha).is_integer():
    return _compute_log_a_int(q, sigma, int(alpha))
  else:
    return _compute_log_a_frac(q, sigma, alpha)


def _log_erfc(x):
  """Compute log(erfc(x)) with high accuracy for large x."""
  try:
    return math.log(2) + special.log_ndtr(-x * 2**.5)
  except NameError:
    # If log_ndtr is not available, approximate as follows:
    r = special.erfc(x)
    if r == 0.0:
      # Using the Laurent series at infinity for the tail of the erfc function:
      # erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
      # To verify in Mathematica:
      # Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
      return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
              .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
    else:
      return math.log(r)


def _compute_delta(orders, rdp, eps):
  """Compute delta given a list of RDP values and target epsilon.

  Args:
    orders: An array (or a scalar) of orders.
    rdp: A list (or a scalar) of RDP guarantees.
    eps: The target epsilon.

  Returns:
    Pair of (delta, optimal_order).

  Raises:
    ValueError: If input is malformed.
  """
  orders_vec = np.atleast_1d(orders)
  rdp_vec = np.atleast_1d(rdp)

  if eps < 0:
    raise ValueError("Value of privacy loss bound epsilon must be >=0.")
  if len(orders_vec) != len(rdp_vec):
    raise ValueError("Input lists must have the same length.")

  # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
  #   delta = min( np.exp((rdp_vec - eps) * (orders_vec - 1)) )

  # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4):
  logdeltas = []  # work in log space to avoid overflows
  for (a, r) in zip(orders_vec, rdp_vec):
    if a < 1:
      raise ValueError("Renyi divergence order must be >=1.")
    if r < 0:
      raise ValueError("Renyi divergence must be >=0.")
    # For small alpha, we are better off with the bound via KL divergence:
    #   delta <= sqrt(1-exp(-KL)).
    # Take a min of the two bounds.
    logdelta = 0.5 * math.log1p(-math.exp(-r))
    if a > 1.01:
      # This bound is not numerically stable as alpha->1.
      # Thus we have a min value for alpha.
      # The bound is also not useful for small alpha, so doesn't matter.
      rdp_bound = (a - 1) * (r - eps + math.log1p(-1 / a)) - math.log(a)
      logdelta = min(logdelta, rdp_bound)

    logdeltas.append(logdelta)

  idx_opt = np.argmin(logdeltas)
  return min(math.exp(logdeltas[idx_opt]), 1.), orders_vec[idx_opt]


def _compute_eps(orders, rdp, delta):
  """Compute epsilon given a list of RDP values and target delta.

  Args:
    orders: An array (or a scalar) of orders.
    rdp: A list (or a scalar) of RDP guarantees.
    delta: The target delta.

  Returns:
    Pair of (eps, optimal_order).

  Raises:
    ValueError: If input is malformed.
  """
  orders_vec = np.atleast_1d(orders)
  rdp_vec = np.atleast_1d(rdp)

  if delta <= 0:
    raise ValueError("Privacy failure probability bound delta must be >0.")
  if len(orders_vec) != len(rdp_vec):
    raise ValueError("Input lists must have the same length.")

  # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
  #   eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) )

  # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4).
  # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1).
  eps_vec = []
  for (a, r) in zip(orders_vec, rdp_vec):
    if a < 1:
      raise ValueError("Renyi divergence order must be >=1.")
    if r < 0:
      raise ValueError("Renyi divergence must be >=0.")

    if delta**2 + math.expm1(-r) >= 0:
      # In this case, we can simply bound via KL divergence:
      #   delta <= sqrt(1-exp(-KL)).
      eps = 0  # No need to try further computation if we have eps = 0.
    elif a > 1.01:
      # This bound is not numerically stable as alpha->1.
      # Thus we have a min value of alpha.
      # The bound is also not useful for small alpha, so doesn't matter.
      eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1)
    else:
      # In this case we can't do anything. E.g., asking for delta = 0.
      eps = np.inf
    eps_vec.append(eps)

  idx_opt = np.argmin(eps_vec)
  return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]


def _stable_inplace_diff_in_log(vec, signs, n=-1):
  """Replaces the first n-1 dims of vec with the log of abs difference operator.

  Args:
    vec: numpy array of floats with size larger than 'n'.
    signs: Optional numpy array of bools with the same size as vec in case one
      needs to compute partial differences. vec and signs jointly describe a
      vector of real numbers' sign and abs in log scale.
    n: Optional upper bound on number of differences to compute. If negative,
      all differences are computed.

  Returns:
    The first n-1 dimensions of vec and signs will store the log-abs and sign
    of the difference.

  Raises:
    ValueError: If input is malformed.
  """

  assert vec.shape == signs.shape
  if n < 0:
    n = np.max(vec.shape) - 1
  else:
    assert np.max(vec.shape) >= n + 1
  for j in range(0, n, 1):
    if signs[j] == signs[j + 1]:  # When the signs are the same
      # if the signs are both positive, then we can just use the standard one
      signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j])
      # otherwise, we do that but toggle the sign
      if not signs[j + 1]:
        signs[j] = ~signs[j]
    else:  # When the signs are different.
      vec[j] = _log_add(vec[j], vec[j + 1])
      signs[j] = signs[j + 1]


def _get_forward_diffs(fun, n):
  """Computes up to nth order forward difference evaluated at 0.

  See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf

  Args:
    fun: Function to compute forward differences of.
    n: Number of differences to compute.

  Returns:
    Pair (deltas, signs_deltas) of the log deltas and their signs.
  """
  func_vec = np.zeros(n + 3)
  signs_func_vec = np.ones(n + 3, dtype=bool)

  # ith coordinate of deltas stores log(abs(ith order discrete derivative))
  deltas = np.zeros(n + 2)
  signs_deltas = np.zeros(n + 2, dtype=bool)
  for i in range(1, n + 3, 1):
    func_vec[i] = fun(1.0 * (i - 1))
  for i in range(0, n + 2, 1):
    # Diff in log scale
    _stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i)
    deltas[i] = func_vec[0]
    signs_deltas[i] = signs_func_vec[0]
  return deltas, signs_deltas


def _compute_rdp(q, sigma, alpha):
  """Compute RDP of the Sampled Gaussian mechanism at order alpha.

  Args:
    q: The sampling rate.
    sigma: The std of the additive Gaussian noise.
    alpha: The order at which RDP is computed.

  Returns:
    RDP at alpha, can be np.inf.
  """
  if q == 0:
    return 0

  if q == 1.:
    return alpha / (2 * sigma**2)

  if np.isinf(alpha):
    return np.inf

  return _compute_log_a(q, sigma, alpha) / (alpha - 1)


def compute_rdp(q, noise_multiplier, steps, orders):
  """Computes RDP of the Sampled Gaussian Mechanism.

  Args:
    q: The sampling rate.
    noise_multiplier: The ratio of the standard deviation of the Gaussian
      noise to the l2-sensitivity of the function to which it is added.
    steps: The number of steps.
    orders: An array (or a scalar) of RDP orders.

  Returns:
    The RDPs at all orders. Can be `np.inf`.
  """
  if np.isscalar(orders):
    rdp = _compute_rdp(q, noise_multiplier, orders)
  else:
    rdp = np.array(
        [_compute_rdp(q, noise_multiplier, order) for order in orders])

  return rdp * steps


def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
  """Compute RDP of Gaussian Mechanism using sampling without replacement.

  This function applies to the following schemes:
  1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n.
  2. ``Replace one data point'' version of differential privacy, i.e., n is
     considered public information.

  Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (A strengthened
  version applies to the subsampled Gaussian mechanism.)
  - Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and
    Analytical Moments Accountant." AISTATS'2019.

  Args:
    q: The sampling proportion = m / n. Assume m is an integer <= n.
    noise_multiplier: The ratio of the standard deviation of the Gaussian
      noise to the l2-sensitivity of the function to which it is added.
    steps: The number of steps.
    orders: An array (or a scalar) of RDP orders.

  Returns:
    The RDPs at all orders, can be np.inf.
  """
  if np.isscalar(orders):
    rdp = _compute_rdp_sample_without_replacement_scalar(
        q, noise_multiplier, orders)
  else:
    rdp = np.array([
        _compute_rdp_sample_without_replacement_scalar(q, noise_multiplier,
                                                       order)
        for order in orders
    ])

  return rdp * steps


def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha):
  """Compute RDP of the Sampled Gaussian mechanism at order alpha.

  Args:
    q: The sampling proportion = m / n. Assume m is an integer <= n.
    sigma: The std of the additive Gaussian noise.
    alpha: The order at which RDP is computed.

  Returns:
    RDP at alpha, can be np.inf.
  """

  assert (q <= 1) and (q >= 0) and (alpha >= 1)

  if q == 0:
    return 0

  if q == 1.:
    return alpha / (2 * sigma**2)

  if np.isinf(alpha):
    return np.inf

  if float(alpha).is_integer():
    return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (
        alpha - 1)
  else:
    # When alpha is not an integer, we apply Corollary 10 of [WBK19] to
    # interpolate the CGF and obtain an upper bound.
    alpha_f = math.floor(alpha)
    alpha_c = math.ceil(alpha)

    x = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_f)
    y = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_c)
    t = alpha - alpha_f
    return ((1 - t) * x + t * y) / (alpha - 1)


def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
  """Compute log(A_alpha) for integer alpha, subsampling without replacement.

  When alpha is smaller than max_alpha, compute the bound of Theorem 27
  exactly; otherwise, compute the bound with the Stirling approximation.

  Args:
    q: The sampling proportion = m / n. Assume m is an integer <= n.
    sigma: The std of the additive Gaussian noise.
    alpha: The order at which RDP is computed.

  Returns:
    RDP at alpha, can be np.inf.
  """

  max_alpha = 256
  assert isinstance(alpha, six.integer_types)

  if np.isinf(alpha):
    return np.inf
  elif alpha == 1:
    return 0

  def cgf(x):
    # Return rdp(x+1)*x; the RDP of the Gaussian mechanism is alpha/(2*sigma**2)
    return x * 1.0 * (x + 1) / (2.0 * sigma**2)

  def func(x):
    # Return the RDP of the Gaussian mechanism
    return 1.0 * x / (2.0 * sigma**2)

  # Initialize with 1 in the log space.
  log_a = 0
  # Calculates the log term when alpha = 2
  log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
  if alpha <= max_alpha:
    # We need forward differences of exp(cgf).
    # The following line is the numerically stable way of implementing it.
    # The output is in polar form with logarithmic magnitude.
    deltas, _ = _get_forward_diffs(cgf, alpha)
    # Computing the bound exactly requires bookkeeping of O(alpha**2)

    for i in range(2, alpha + 1):
      if i == 2:
        s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
            np.log(4) + log_f2m1,
            func(2.0) + np.log(2))
      elif i > 2:
        delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1]
        delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
        s = np.log(4) + 0.5 * (delta_lo + delta_hi)
        s = np.minimum(s, np.log(2) + cgf(i - 1))
        s += i * np.log(q) + _log_comb(alpha, i)
      log_a = _log_add(log_a, s)
    return float(log_a)
  else:
    # Compute the bound with the Stirling approximation. Everything is O(x) now.
    for i in range(2, alpha + 1):
      if i == 2:
        s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
            np.log(4) + log_f2m1,
            func(2.0) + np.log(2))
      else:
        s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i)
      log_a = _log_add(log_a, s)

    return log_a


def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
                             steps_list, orders):
  """Computes RDP of heterogeneous applications of Sampled Gaussian Mechanisms.

  Args:
    sampling_probabilities: A list containing the sampling rates.
    noise_multipliers: A list containing the noise multipliers: the ratio of
      the standard deviation of the Gaussian noise to the l2-sensitivity of
      the function to which it is added.
    steps_list: A list containing the number of steps at each
      `sampling_probability` and `noise_multiplier`.
    orders: An array (or a scalar) of RDP orders.

  Returns:
    The RDPs at all orders. Can be `np.inf`.
  """
  assert len(sampling_probabilities) == len(noise_multipliers)

  rdp = 0
  for q, noise_multiplier, steps in zip(sampling_probabilities,
                                        noise_multipliers, steps_list):
    rdp += compute_rdp(q, noise_multiplier, steps, orders)

  return rdp


def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
  """Computes delta (or eps) for given eps (or delta) from RDP values.

  Args:
    orders: An array (or a scalar) of RDP orders.
    rdp: An array of RDP values. Must be of the same length as the orders
      list.
    target_eps: If not `None`, the epsilon for which we compute the
      corresponding delta.
    target_delta: If not `None`, the delta for which we compute the
      corresponding epsilon. Exactly one of `target_eps` and `target_delta`
      must be `None`.

  Returns:
    A tuple of epsilon, delta, and the optimal order.

  Raises:
    ValueError: If target_eps and target_delta are messed up.
  """
  if target_eps is None and target_delta is None:
    raise ValueError(
        "Exactly one out of eps and delta must be None. (Both are).")

  if target_eps is not None and target_delta is not None:
    raise ValueError(
        "Exactly one out of eps and delta must be None. (None is).")

  if target_eps is not None:
    delta, opt_order = _compute_delta(orders, rdp, target_eps)
    return target_eps, delta, opt_order
  else:
    eps, opt_order = _compute_eps(orders, rdp, target_delta)
    return eps, target_delta, opt_order


def compute_rdp_from_ledger(ledger, orders):
  """Computes RDP of Sampled Gaussian Mechanism from ledger.

  Args:
    ledger: A formatted privacy ledger.
    orders: An array (or a scalar) of RDP orders.

  Returns:
    RDP at all orders. Can be `np.inf`.
  """
  total_rdp = np.zeros_like(orders, dtype=float)
  for sample in ledger:
    # Compute equivalent z from l2_clip_bounds and noise stddevs in sample.
    # See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula.
    effective_z = sum([
        (q.noise_stddev / q.l2_norm_bound)**-2 for q in sample.queries
    ])**-0.5
    total_rdp += compute_rdp(sample.selection_probability, effective_z, 1,
                             orders)
  return total_rdp
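As a concrete companion to the docstring example at the top of the file, a
minimal end-to-end sketch for a training schedule with two phases (all
parameter values below are illustrative assumptions):

```
import rdp_accountant

orders = [1.5, 2., 4., 8., 16., 32., 64.]
# Two phases with different sampling rates and noise multipliers.
rdp = rdp_accountant.compute_heterogenous_rdp(
    sampling_probabilities=[0.01, 0.02],
    noise_multipliers=[1.2, 1.0],
    steps_list=[500, 500],
    orders=orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
                                                     target_delta=1e-5)
print('({:.2f}, 1e-5)-DP, optimal order {}'.format(eps, opt_order))
```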
@@ -45,7 +45,10 @@ py_binary(
 py_library(
     name = "compute_noise_from_budget_lib",
     srcs = ["compute_noise_from_budget_lib.py"],
-    deps = [":rdp_accountant"],
+    deps = [
+        "@com_google_differential_py//python/dp_accounting:dp_event",
+        "@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
+    ],
 )

 py_test(
@@ -19,22 +19,18 @@ import math
 from absl import app
 from scipy import optimize

-from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp  # pylint: disable=g-import-not-at-top
-from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
+from com_google_differential_py.python.dp_accounting import dp_event
+from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant


 def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
   """Compute and print results of DP-SGD analysis."""

-  # compute_rdp requires that sigma be the ratio of the standard deviation of
-  # the Gaussian noise to the l2-sensitivity of the function to which it is
-  # added. Hence, sigma here corresponds to the `noise_multiplier` parameter
-  # in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
-  rdp = compute_rdp(q, sigma, steps, orders)
-
-  eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
-
-  return eps, opt_order
+  accountant = rdp_privacy_accountant.RdpAccountant(orders)
+  event = dp_event.SelfComposedDpEvent(
+      dp_event.PoissonSampledDpEvent(q, dp_event.GaussianDpEvent(sigma)), steps)
+  accountant.compose(event)
+  return accountant.get_epsilon_and_optimal_order(delta)


 def compute_noise(n, batch_size, target_epsilon, epochs, delta, noise_lbd):
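The hunks above and below migrate this code from the legacy `rdp_accountant`
functions to Google's `dp_accounting` package: the DP-SGD run is described
declaratively as a `GaussianDpEvent` wrapped in a `PoissonSampledDpEvent`,
self-composed over the number of steps, and an `RdpAccountant` then converts
the composed event into an epsilon at the target delta. The underlying RDP
analysis is essentially the same; only the API changes.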
@@ -94,6 +94,8 @@ py_binary(
         "//tensorflow_privacy/privacy/optimizers:dp_optimizer",
         "//third_party/py/tensorflow:tensorflow_compat_v1_estimator",
         "//third_party/py/tensorflow:tensorflow_estimator",
+        "@com_google_differential_py//python/dp_accounting:dp_event",
+        "@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
     ],
 )

@@ -30,9 +30,10 @@ import numpy as np
 import tensorflow as tf
 from tensorflow import estimator as tf_estimator
 from tensorflow.compat.v1 import estimator as tf_compat_v1_estimator
-from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
 from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
 from tensorflow_privacy.privacy.optimizers import dp_optimizer
+from com_google_differential_py.python.dp_accounting import dp_event
+from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant

 GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer

@@ -169,10 +170,14 @@ def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
     eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
     print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(p * 100, eps, delta))

   # Compute privacy guarantees for the Sampled Gaussian Mechanism.
-  rdp_sgm = compute_rdp(batch_size / samples, noise_multiplier,
-                        epochs * steps_per_epoch, orders)
-  eps_sgm, _, _ = get_privacy_spent(orders, rdp_sgm, target_delta=delta)
+  accountant = rdp_privacy_accountant.RdpAccountant(orders)
+  event = dp_event.SelfComposedDpEvent(
+      dp_event.PoissonSampledDpEvent(
+          batch_size / samples, dp_event.GaussianDpEvent(noise_multiplier)),
+      epochs * steps_per_epoch)
+  accountant.compose(event)
+  eps_sgm = accountant.get_epsilon(target_delta=delta)

   print('By comparison, DP-SGD analysis for training done with the same '
         'parameters and random shuffling in each epoch guarantees '
         '({:.2f}, {})-DP for all samples.'.format(eps_sgm, delta))