Renaming stddev_to_sensitivity_ratio to noise_multiplier in rdp_accountant.
PiperOrigin-RevId: 227552068
This commit is contained in:
parent
205e005f60
commit
01ab549902
4 changed files with 36 additions and 37 deletions
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -12,12 +12,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""RDP analysis of the Sampled Gaussian mechanism.
|
"""RDP analysis of the Sampled Gaussian Mechanism.
|
||||||
|
|
||||||
Functionality for computing Renyi differential privacy (RDP) of an additive
|
Functionality for computing Renyi differential privacy (RDP) of an additive
|
||||||
Sampled Gaussian mechanism (SGM). Its public interface consists of two methods:
|
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
|
||||||
compute_rdp(q, stddev_to_sensitivity_ratio, T, orders) computes RDP with for
|
compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated
|
||||||
SGM iterated T times.
|
T times.
|
||||||
get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
|
get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
|
||||||
(or eps) given RDP at multiple orders and
|
(or eps) given RDP at multiple orders and
|
||||||
a target value for eps (or delta).
|
a target value for eps (or delta).
|
||||||
|
@ -63,7 +63,7 @@ def _log_add(logx, logy):
|
||||||
def _log_sub(logx, logy):
|
def _log_sub(logx, logy):
|
||||||
"""Subtract two numbers in the log space. Answer must be non-negative."""
|
"""Subtract two numbers in the log space. Answer must be non-negative."""
|
||||||
if logx < logy:
|
if logx < logy:
|
||||||
raise ValueError("The result of subtraction must be non-negative .")
|
raise ValueError("The result of subtraction must be non-negative.")
|
||||||
if logy == -np.inf: # subtracting 0
|
if logy == -np.inf: # subtracting 0
|
||||||
return logx
|
return logx
|
||||||
if logx == logy:
|
if logx == logy:
|
||||||
|
@ -104,7 +104,7 @@ def _compute_log_a_int(q, sigma, alpha):
|
||||||
|
|
||||||
def _compute_log_a_frac(q, sigma, alpha):
|
def _compute_log_a_frac(q, sigma, alpha):
|
||||||
"""Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
|
"""Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
|
||||||
# The two parts of A_alpha, integrals over (-inf,z0] and (z0, +inf), are
|
# The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
|
||||||
# initialized to 0 in the log space:
|
# initialized to 0 in the log space:
|
||||||
log_a0, log_a1 = -np.inf, -np.inf
|
log_a0, log_a1 = -np.inf, -np.inf
|
||||||
i = 0
|
i = 0
|
||||||
|
@ -148,6 +148,7 @@ def _compute_log_a(q, sigma, alpha):
|
||||||
|
|
||||||
|
|
||||||
def _log_erfc(x):
|
def _log_erfc(x):
|
||||||
|
"""Compute log(erfc(x)) with high accuracy for large x."""
|
||||||
try:
|
try:
|
||||||
return math.log(2) + special.log_ndtr(-x * 2**.5)
|
return math.log(2) + special.log_ndtr(-x * 2**.5)
|
||||||
except NameError:
|
except NameError:
|
||||||
|
@ -165,7 +166,7 @@ def _log_erfc(x):
|
||||||
|
|
||||||
|
|
||||||
def _compute_delta(orders, rdp, eps):
|
def _compute_delta(orders, rdp, eps):
|
||||||
"""Compute delta given an RDP curve and target epsilon.
|
"""Compute delta given a list of RDP values and target epsilon.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
orders: An array (or a scalar) of orders.
|
orders: An array (or a scalar) of orders.
|
||||||
|
@ -191,7 +192,7 @@ def _compute_delta(orders, rdp, eps):
|
||||||
|
|
||||||
|
|
||||||
def _compute_eps(orders, rdp, delta):
|
def _compute_eps(orders, rdp, delta):
|
||||||
"""Compute epsilon given an RDP curve and target delta.
|
"""Compute epsilon given a list of RDP values and target delta.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
orders: An array (or a scalar) of orders.
|
orders: An array (or a scalar) of orders.
|
||||||
|
@ -240,31 +241,30 @@ def _compute_rdp(q, sigma, alpha):
|
||||||
return _compute_log_a(q, sigma, alpha) / (alpha - 1)
|
return _compute_log_a(q, sigma, alpha) / (alpha - 1)
|
||||||
|
|
||||||
|
|
||||||
def compute_rdp(q, stddev_to_sensitivity_ratio, steps, orders):
|
def compute_rdp(q, noise_multiplier, steps, orders):
|
||||||
"""Compute RDP of the Sampled Gaussian Mechanism for given parameters.
|
"""Compute RDP of the Sampled Gaussian Mechanism.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: The sampling rate.
|
q: The sampling rate.
|
||||||
stddev_to_sensitivity_ratio: The ratio of std of the Gaussian noise to the
|
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
|
||||||
l2-sensitivity of the function to which it is added.
|
to the l2-sensitivity of the function to which it is added.
|
||||||
steps: The number of steps.
|
steps: The number of steps.
|
||||||
orders: An array (or a scalar) of RDP orders.
|
orders: An array (or a scalar) of RDP orders.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The RDPs at all orders (each can be np.inf).
|
The RDPs at all orders (each can be np.inf).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if np.isscalar(orders):
|
if np.isscalar(orders):
|
||||||
rdp = _compute_rdp(q, stddev_to_sensitivity_ratio, orders)
|
rdp = _compute_rdp(q, noise_multiplier, orders)
|
||||||
else:
|
else:
|
||||||
rdp = np.array([_compute_rdp(q, stddev_to_sensitivity_ratio, order)
|
rdp = np.array([_compute_rdp(q, noise_multiplier, order)
|
||||||
for order in orders])
|
for order in orders])
|
||||||
|
|
||||||
return rdp * steps
|
return rdp * steps
|
||||||
|
|
||||||
|
|
||||||
def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
|
def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
|
||||||
"""Compute delta (or eps) for given eps (or delta) from the RDP curve.
|
"""Compute delta (or eps) for given eps (or delta) from RDP values.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
orders: An array (or a scalar) of RDP orders.
|
orders: An array (or a scalar) of RDP orders.
|
||||||
|
@ -273,6 +273,7 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
|
||||||
delta.
|
delta.
|
||||||
target_delta: If not None, the delta for which we compute the corresponding
|
target_delta: If not None, the delta for which we compute the corresponding
|
||||||
epsilon. Exactly one of target_eps and target_delta must be None.
|
epsilon. Exactly one of target_eps and target_delta must be None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
eps, delta, opt_order.
|
eps, delta, opt_order.
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -27,7 +27,6 @@ import numpy as np
|
||||||
|
|
||||||
from privacy.analysis import rdp_accountant
|
from privacy.analysis import rdp_accountant
|
||||||
|
|
||||||
|
|
||||||
class TestGaussianMoments(parameterized.TestCase):
|
class TestGaussianMoments(parameterized.TestCase):
|
||||||
#################################
|
#################################
|
||||||
# HELPER FUNCTIONS: #
|
# HELPER FUNCTIONS: #
|
||||||
|
@ -134,7 +133,7 @@ class TestGaussianMoments(parameterized.TestCase):
|
||||||
16., 20., 24., 28., 32., 64., 256.)
|
16., 20., 24., 28., 32., 64., 256.)
|
||||||
|
|
||||||
rdp = rdp_accountant.compute_rdp(q=1e-4,
|
rdp = rdp_accountant.compute_rdp(q=1e-4,
|
||||||
stddev_to_sensitivity_ratio=.4,
|
noise_multiplier=.4,
|
||||||
steps=40000,
|
steps=40000,
|
||||||
orders=orders)
|
orders=orders)
|
||||||
|
|
||||||
|
@ -142,7 +141,7 @@ class TestGaussianMoments(parameterized.TestCase):
|
||||||
target_delta=1e-6)
|
target_delta=1e-6)
|
||||||
|
|
||||||
rdp += rdp_accountant.compute_rdp(q=0.1,
|
rdp += rdp_accountant.compute_rdp(q=0.1,
|
||||||
stddev_to_sensitivity_ratio=2,
|
noise_multiplier=2,
|
||||||
steps=100,
|
steps=100,
|
||||||
orders=orders)
|
orders=orders)
|
||||||
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
|
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Tutorials
|
# Tutorials
|
||||||
|
|
||||||
As demonstrated on MNIST in `mnist_dpsgd_tutorial.py`, the easiest to use
|
As demonstrated on MNIST in `mnist_dpsgd_tutorial.py`, the easiest way to use
|
||||||
a differentially private optimizer is to modify an existing training loop
|
a differentially private optimizer is to modify an existing training loop
|
||||||
to replace an existing vanilla optimizer with its differentially private
|
to replace an existing vanilla optimizer with its differentially private
|
||||||
counterpart implemented in the library.
|
counterpart implemented in the library.
|
||||||
|
@ -26,26 +26,25 @@ be tuned in addition to any existing hyperparameter. There are currently three:
|
||||||
|
|
||||||
## Measuring Privacy
|
## Measuring Privacy
|
||||||
|
|
||||||
Differential privacy is measured by two values, epsilon and delta. Roughly
|
Differential privacy can be expressed using two values, epsilon and delta.
|
||||||
speaking, they mean the following:
|
Roughly speaking, they mean the following:
|
||||||
|
|
||||||
* epsilon gives a ceiling on how much the probability of a change in model
|
* epsilon gives a ceiling on how much the probability of a particular output
|
||||||
behavior can increase by including a single extra training example. This is
|
can increase by including (or removing) a single training example. We usually
|
||||||
the far more sensitive value, and we usually want it to be at most 10.0 or
|
want it to be a small constant (less than 10, or, for more stringent privacy
|
||||||
so. However, note that this is only an upper bound, and a large value of
|
guarantees, less than 1). However, this is only an upper bound, and a large
|
||||||
epsilon may still mean good practical privacy.
|
value of epsilon may still mean good practical privacy.
|
||||||
* delta bounds the probability of an "unconditional" change in model behavior.
|
* delta bounds the probability of an arbitrary change in model behavior.
|
||||||
We can usually set this to a very small number (1e-7 or so) without
|
We can usually set this to a very small number (1e-7 or so) without
|
||||||
compromising utility. A rule of thumb is to set it to the inverse of the
|
compromising utility. A rule of thumb is to set it to be less than the inverse
|
||||||
order of magnitude of the training data size.
|
of the training data size.
|
||||||
|
|
||||||
To find out the epsilon given a fixed delta value for your model, follow the
|
To find out the epsilon given a fixed delta value for your model, follow the
|
||||||
approach demonstrated in `compute_epsilon` in `mnist_dpsgd_tutorial.py`,
|
approach demonstrated in `compute_epsilon` in `mnist_dpsgd_tutorial.py`,
|
||||||
where the arguments used to call the RDP accountant (i.e., the tool used to
|
where the arguments used to call the RDP accountant (i.e., the tool used to
|
||||||
compute the privacy guarantee) are:
|
compute the privacy guarantee) are:
|
||||||
|
|
||||||
* q : The sampling ratio, defined as (number of examples consumed in one
|
* q : The sampling ratio, defined as (number of examples consumed in one
|
||||||
step) / (total training examples).
|
step) / (total training examples).
|
||||||
* stddev_to_sensitivity_ratio : The noise_multiplier from your parameters above.
|
* noise_multiplier : The ratio of the Gaussian noise's standard deviation to the l2-sensitivity; pass the noise_multiplier from your parameters above.
|
||||||
* steps : The number of global steps taken.
|
* steps : The number of global steps taken.
|
||||||
|
|
||||||
|
|
|
@ -156,7 +156,7 @@ def main(unused_argv):
|
||||||
orders = [1 + x / 10. for x in range(1, 100)] + range(12, 64)
|
orders = [1 + x / 10. for x in range(1, 100)] + range(12, 64)
|
||||||
sampling_probability = FLAGS.batch_size / 60000
|
sampling_probability = FLAGS.batch_size / 60000
|
||||||
rdp = compute_rdp(q=sampling_probability,
|
rdp = compute_rdp(q=sampling_probability,
|
||||||
stddev_to_sensitivity_ratio=FLAGS.noise_multiplier,
|
noise_multiplier=FLAGS.noise_multiplier,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
orders=orders)
|
orders=orders)
|
||||||
# Delta is set to 1e-5 because MNIST has 60000 training points.
|
# Delta is set to 1e-5 because MNIST has 60000 training points.
|
||||||
|
|
Loading…
Reference in a new issue