forked from 626_privacy/tensorflow_privacy

Project import generated by Copybara.
PiperOrigin-RevId: 226056146

parent: ceee90b1ac
commit: 1595ed3cd1

12 changed files with 918 additions and 254 deletions

README.md (90 changes)
@@ -1,8 +1,88 @@
 # TensorFlow Privacy
 
-This repository will contain implementations of TensorFlow optimizers that
-support training machine learning models with (differential) privacy, as well
-as tutorials and analysis tools for computing the privacy guarantees provided.
+This repository contains the source code for TensorFlow Privacy, a Python
+library that includes implementations of TensorFlow optimizers for training
+machine learning models with differential privacy. The library comes with
+tutorials and analysis tools for computing the privacy guarantees provided.
 
-The content of this repository will supersede the following existing repository:
-https://github.com/tensorflow/models/tree/master/research/differential_privacy
+The TensorFlow Privacy library is under continual development, always welcoming
+contributions. In particular, we always welcome help towards resolving the
+issues currently open.
+
+## Setting up TensorFlow Privacy
+
+### Dependencies
+
+This library uses [TensorFlow](https://www.tensorflow.org/) to define machine
+learning models. Therefore, installing TensorFlow is a prerequisite. You can
+find instructions [here](https://www.tensorflow.org/install/). For better
+performance, it is also recommended to install TensorFlow with GPU support
+(detailed instructions on how to do this are available in the TensorFlow
+installation documentation).
+
+Installing TensorFlow will take care of all other dependencies like `numpy` and
+`scipy`.
+
+### Installing TensorFlow Privacy
+
+First, clone this GitHub repository into a directory of your choice:
+
+```
+git clone https://github.com/tensorflow/privacy
+```
+
+You can then install the local package in "editable" mode in order to add it to
+your `PYTHONPATH`:
+
+```
+cd privacy
+pip install -e ./privacy
+```
+
+If you'd like to make contributions, we recommend first forking the repository
+and then cloning your fork rather than cloning this repository directly.
+
+## Contributing
+
+Contributions are welcome! Bug fixes and new features can be initiated through
+GitHub pull requests. To speed the code review process, we ask that:
+
+*   When making code contributions to TensorFlow Privacy, you follow the
+    `PEP8 with two spaces` coding style (the same as the one used by
+    TensorFlow) in your pull requests. In most cases this can be done by
+    running `autopep8 -i --indent-size 2 <file>` on the files you have edited.
+
+*   When making your first pull request, you
+    [sign the Google CLA](https://cla.developers.google.com/clas).
+
+*   We do not accept pull requests that add git submodules because of
+    [the problems that arise when maintaining git submodules](https://medium.com/@porteneuve/mastering-git-submodules-34c65e940407).
+
+## Tutorials directory
+
+To help you get started with the functionalities provided by this library, the
+`tutorials/` folder comes with scripts demonstrating how to use the library
+features.
+
+NOTE: the tutorials are maintained carefully. However, they are not considered
+part of the API and they can change at any time without warning. You should not
+write third-party code that imports the tutorials and expect that the interface
+will not break.
+
+## Remarks
+
+The content of this repository supersedes the following existing folder in the
+tensorflow/models [repository](https://github.com/tensorflow/models/tree/master/research/differential_privacy).
+
+## Contacts
+
+If you have any questions that cannot be addressed by raising an issue, feel
+free to contact:
+
+*   Nicolas Papernot (@npapernot)
+*   Steve Chien
+*   Galen Andrew (@galenmandrew)
+
+## Copyright
+
+Copyright 2018 - Google LLC
privacy/__init__.py (new, empty file)
privacy/analysis/__init__.py (new, empty file)

privacy/analysis/rdp_accountant.py (new file, 295 lines)
@@ -0,0 +1,295 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RDP analysis of the Sampled Gaussian mechanism.

Functionality for computing Renyi differential privacy (RDP) of an additive
Sampled Gaussian mechanism (SGM). Its public interface consists of two methods:
  compute_rdp(q, stddev_to_sensitivity_ratio, T, orders) computes RDP for an
    SGM iterated T times.
  get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
    (or eps) given RDP at multiple orders and a target value for eps (or
    delta).

Example use:

Suppose that we have run an SGM applied to a function with l2-sensitivity 1.
Its parameters are given as a list of tuples (q1, sigma1, T1), ...,
(qk, sigma_k, Tk), and we wish to compute eps for a given delta.
The example code would be:

  max_order = 32
  orders = range(2, max_order + 1)
  rdp = np.zeros_like(orders, dtype=float)
  for q, sigma, T in parameters:
    rdp += rdp_accountant.compute_rdp(q, sigma, T, orders)
  eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
                                                       target_delta=delta)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import sys

import numpy as np
from scipy import special

########################
# LOG-SPACE ARITHMETIC #
########################


def _log_add(logx, logy):
  """Add two numbers in the log space."""
  a, b = min(logx, logy), max(logx, logy)
  if a == -np.inf:  # adding 0
    return b
  # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
  return math.log1p(math.exp(a - b)) + b  # log1p(x) = log(x + 1)


def _log_sub(logx, logy):
  """Subtract two numbers in the log space. Answer must be non-negative."""
  if logx < logy:
    raise ValueError("The result of subtraction must be non-negative.")
  if logy == -np.inf:  # subtracting 0
    return logx
  if logx == logy:
    return -np.inf  # 0 is represented as -np.inf in the log space.

  try:
    # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
    return math.log(math.expm1(logx - logy)) + logy  # expm1(x) = exp(x) - 1
  except OverflowError:
    return logx


def _log_print(logx):
  """Pretty print."""
  if logx < math.log(sys.float_info.max):
    return "{}".format(math.exp(logx))
  else:
    return "exp({})".format(logx)


def _compute_log_a_int(q, sigma, alpha):
  """Compute log(A_alpha) for integer alpha. 0 < q < 1."""
  assert isinstance(alpha, (int, long))

  # Initialize with 0 in the log space.
  log_a = -np.inf

  for i in range(alpha + 1):
    log_coef_i = (
        math.log(special.binom(alpha, i)) + i * math.log(q) +
        (alpha - i) * math.log(1 - q))

    s = log_coef_i + (i * i - i) / (2 * (sigma**2))
    log_a = _log_add(log_a, s)

  return float(log_a)


def _compute_log_a_frac(q, sigma, alpha):
  """Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
  # The two parts of A_alpha, integrals over (-inf,z0] and (z0, +inf), are
  # initialized to 0 in the log space:
  log_a0, log_a1 = -np.inf, -np.inf
  i = 0

  z0 = sigma**2 * math.log(1 / q - 1) + .5

  while True:  # do ... until loop
    coef = special.binom(alpha, i)
    log_coef = math.log(abs(coef))
    j = alpha - i

    log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
    log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)

    log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
    log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))

    log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
    log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1

    if coef > 0:
      log_a0 = _log_add(log_a0, log_s0)
      log_a1 = _log_add(log_a1, log_s1)
    else:
      log_a0 = _log_sub(log_a0, log_s0)
      log_a1 = _log_sub(log_a1, log_s1)

    i += 1
    if max(log_s0, log_s1) < -30:
      break

  return _log_add(log_a0, log_a1)


def _compute_log_a(q, sigma, alpha):
  """Compute log(A_alpha) for any positive finite alpha."""
  if float(alpha).is_integer():
    return _compute_log_a_int(q, sigma, int(alpha))
  else:
    return _compute_log_a_frac(q, sigma, alpha)


def _log_erfc(x):
  try:
    return math.log(2) + special.log_ndtr(-x * 2**.5)
  except NameError:
    # If log_ndtr is not available, approximate as follows:
    r = special.erfc(x)
    if r == 0.0:
      # Using the Laurent series at infinity for the tail of the erfc function:
      #   erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
      # To verify in Mathematica:
      #   Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
      return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
              .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
    else:
      return math.log(r)


def _compute_delta(orders, rdp, eps):
  """Compute delta given an RDP curve and target epsilon.

  Args:
    orders: An array (or a scalar) of orders.
    rdp: A list (or a scalar) of RDP guarantees.
    eps: The target epsilon.

  Returns:
    Pair of (delta, optimal_order).

  Raises:
    ValueError: If input is malformed.
  """
  orders_vec = np.atleast_1d(orders)
  rdp_vec = np.atleast_1d(rdp)

  if len(orders_vec) != len(rdp_vec):
    raise ValueError("Input lists must have the same length.")

  deltas = np.exp((rdp_vec - eps) * (orders_vec - 1))
  idx_opt = np.argmin(deltas)
  return min(deltas[idx_opt], 1.), orders_vec[idx_opt]


def _compute_eps(orders, rdp, delta):
  """Compute epsilon given an RDP curve and target delta.

  Args:
    orders: An array (or a scalar) of orders.
    rdp: A list (or a scalar) of RDP guarantees.
    delta: The target delta.

  Returns:
    Pair of (eps, optimal_order).

  Raises:
    ValueError: If input is malformed.
  """
  orders_vec = np.atleast_1d(orders)
  rdp_vec = np.atleast_1d(rdp)

  if len(orders_vec) != len(rdp_vec):
    raise ValueError("Input lists must have the same length.")

  eps = rdp_vec - math.log(delta) / (orders_vec - 1)

  idx_opt = np.nanargmin(eps)  # Ignore NaNs
  return eps[idx_opt], orders_vec[idx_opt]


def _compute_rdp(q, sigma, alpha):
  """Compute RDP of the Sampled Gaussian mechanism at order alpha.

  Args:
    q: The sampling rate.
    sigma: The std of the additive Gaussian noise.
    alpha: The order at which RDP is computed.

  Returns:
    RDP at alpha, can be np.inf.
  """
  if q == 0:
    return 0

  if q == 1.:
    return alpha / (2 * sigma**2)

  if np.isinf(alpha):
    return np.inf

  return _compute_log_a(q, sigma, alpha) / (alpha - 1)


def compute_rdp(q, stddev_to_sensitivity_ratio, steps, orders):
  """Compute RDP of the Sampled Gaussian Mechanism for given parameters.

  Args:
    q: The sampling rate.
    stddev_to_sensitivity_ratio: The ratio of std of the Gaussian noise to the
      l2-sensitivity of the function to which it is added.
    steps: The number of steps.
    orders: An array (or a scalar) of RDP orders.

  Returns:
    The RDPs at all orders, can be np.inf.
  """
  if np.isscalar(orders):
    rdp = _compute_rdp(q, stddev_to_sensitivity_ratio, orders)
  else:
    rdp = np.array([_compute_rdp(q, stddev_to_sensitivity_ratio, order)
                    for order in orders])

  return rdp * steps


def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
  """Compute delta (or eps) for given eps (or delta) from the RDP curve.

  Args:
    orders: An array (or a scalar) of RDP orders.
    rdp: An array of RDP values. Must be of the same length as the orders list.
    target_eps: If not None, the epsilon for which we compute the corresponding
      delta.
    target_delta: If not None, the delta for which we compute the corresponding
      epsilon. Exactly one of target_eps and target_delta must be None.

  Returns:
    eps, delta, opt_order.

  Raises:
    ValueError: If target_eps and target_delta are messed up.
  """
  if target_eps is None and target_delta is None:
    raise ValueError(
        "Exactly one out of eps and delta must be None. (Both are).")

  if target_eps is not None and target_delta is not None:
    raise ValueError(
        "Exactly one out of eps and delta must be None. (None is).")

  if target_eps is not None:
    delta, opt_order = _compute_delta(orders, rdp, target_eps)
    return target_eps, delta, opt_order
  else:
    eps, opt_order = _compute_eps(orders, rdp, target_delta)
    return eps, target_delta, opt_order
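A minimal sketch of how the two public methods compose (the values and the expected output mirror `rdp_accountant_test.py` below; the `privacy.analysis` import path is the one the tutorial in this commit uses):

```python
from privacy.analysis import rdp_accountant

# RDP of the Sampled Gaussian mechanism with sampling rate q = 0.01 and noise
# stddev sigma = 4 (relative to an l2-sensitivity of 1), over 10000 steps.
orders = range(2, 33)
rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)

# Convert the RDP curve into an (eps, delta) guarantee at a target delta.
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
                                                     target_delta=1e-5)
print(eps, opt_order)  # The test below asserts eps ~= 1.258575 at order 20.
```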
privacy/analysis/rdp_accountant_test.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rdp_accountant.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

from absl.testing import absltest
from absl.testing import parameterized
import mpmath as mp
import numpy as np

import rdp_accountant


class TestGaussianMoments(parameterized.TestCase):
  #################################
  # HELPER FUNCTIONS:             #
  # Exact computations using      #
  # multi-precision arithmetic.   #
  #################################

  def _log_float_mp(self, x):
    # Convert multi-precision input to float log space.
    if x >= sys.float_info.min:
      return float(mp.log(x))
    else:
      return -np.inf

  def _integral_mp(self, fn, bounds=(-mp.inf, mp.inf)):
    integral, _ = mp.quad(fn, bounds, error=True, maxdegree=8)
    return integral

  def _distributions_mp(self, sigma, q):

    def _mu0(x):
      return mp.npdf(x, mu=0, sigma=sigma)

    def _mu1(x):
      return mp.npdf(x, mu=1, sigma=sigma)

    def _mu(x):
      return (1 - q) * _mu0(x) + q * _mu1(x)

    return _mu0, _mu  # Closure!

  def _mu1_over_mu0(self, x, sigma):
    # Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
    return mp.exp((2 * x - 1) / (2 * sigma**2))

  def _mu_over_mu0(self, x, q, sigma):
    return (1 - q) + q * self._mu1_over_mu0(x, sigma)

  def _compute_a_mp(self, sigma, q, alpha):
    """Compute A_alpha for arbitrary alpha by numerical integration."""
    mu0, _ = self._distributions_mp(sigma, q)
    a_alpha_fn = lambda z: mu0(z) * self._mu_over_mu0(z, q, sigma)**alpha
    a_alpha = self._integral_mp(a_alpha_fn)
    return a_alpha

  # TEST ROUTINES
  def test_compute_rdp_no_data(self):
    # q = 0
    self.assertEqual(rdp_accountant.compute_rdp(0, 10, 1, 20), 0)

  def test_compute_rdp_no_sampling(self):
    # q = 1, RDP = alpha / (2 * sigma^2)
    self.assertEqual(rdp_accountant.compute_rdp(1, 10, 1, 20), 0.1)

  def test_compute_rdp_scalar(self):
    rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
    self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)

  def test_compute_rdp_sequence(self):
    rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
                                         [1.5, 2.5, 5, 50, 100, np.inf])
    self.assertSequenceAlmostEqual(
        rdp_vec, [0.00065, 0.001085, 0.00218075, 0.023846, 167.416307, np.inf],
        delta=1e-5)

  params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
            {'q': 1e-6, 'sigma': .1, 'order': 256},
            {'q': 1e-5, 'sigma': .1, 'order': 256.1},
            {'q': 1e-6, 'sigma': 1, 'order': 27},
            {'q': 1e-4, 'sigma': 1., 'order': 1.5},
            {'q': 1e-3, 'sigma': 1., 'order': 2},
            {'q': .01, 'sigma': 10, 'order': 20},
            {'q': .1, 'sigma': 100, 'order': 20.5},
            {'q': .99, 'sigma': .1, 'order': 256},
            {'q': .999, 'sigma': 100, 'order': 256.1})

  # pylint:disable=undefined-variable
  @parameterized.parameters(p for p in params)
  def test_compute_log_a_equals_mp(self, q, sigma, order):
    # Compare the cheap computation of log(A) with an expensive,
    # multi-precision computation.
    log_a = rdp_accountant._compute_log_a(q, sigma, order)
    log_a_mp = self._log_float_mp(self._compute_a_mp(sigma, q, order))
    np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4)

  def test_get_privacy_spent_check_target_delta(self):
    orders = range(2, 33)
    rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
    eps, _, opt_order = rdp_accountant.get_privacy_spent(
        orders, rdp, target_delta=1e-5)
    self.assertAlmostEqual(eps, 1.258575, places=5)
    self.assertEqual(opt_order, 20)

  def test_get_privacy_spent_check_target_eps(self):
    orders = range(2, 33)
    rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
    _, delta, opt_order = rdp_accountant.get_privacy_spent(
        orders, rdp, target_eps=1.258575)
    self.assertAlmostEqual(delta, 1e-5)
    self.assertEqual(opt_order, 20)

  def test_check_composition(self):
    orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14.,
              16., 20., 24., 28., 32., 64., 256.)

    rdp = rdp_accountant.compute_rdp(q=1e-4,
                                     stddev_to_sensitivity_ratio=.4,
                                     steps=40000,
                                     orders=orders)

    eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
                                                         target_delta=1e-6)

    rdp += rdp_accountant.compute_rdp(q=0.1,
                                      stddev_to_sensitivity_ratio=2,
                                      steps=100,
                                      orders=orders)
    eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
                                                         target_delta=1e-5)
    self.assertAlmostEqual(eps, 8.509656, places=5)
    self.assertEqual(opt_order, 2.5)


if __name__ == '__main__':
  absltest.main()
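For reference, the quantity the multi-precision helpers above integrate can be written out explicitly (reconstructed from `_distributions_mp` and `_compute_a_mp`): with

```
\mu_0(z) = \mathcal{N}(z \mid 0, \sigma^2), \qquad
\mu_1(z) = \mathcal{N}(z \mid 1, \sigma^2), \qquad
\mu = (1 - q)\,\mu_0 + q\,\mu_1,

A_\alpha = \int_{-\infty}^{\infty} \mu_0(z)
           \left( \frac{\mu(z)}{\mu_0(z)} \right)^{\alpha} \mathrm{d}z,
```

the test asserts that `_compute_log_a(q, sigma, order)` matches log A_alpha computed by this integral to a relative tolerance of 1e-4.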
privacy/optimizers/__init__.py (new, empty file)

privacy/optimizers/dp_adam.py (deleted, 123 lines)
@@ -1,123 +0,0 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DPAdamOptimizer for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import privacy.optimizers.gaussian_average_query as ph


class DPAdamOptimizer(tf.train.AdamOptimizer):
  """Optimizer that implements the DP Adam algorithm."""

  def __init__(self,
               learning_rate,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               use_locking=False,
               l2_norm_clip=1e9,
               noise_multiplier=0.0,
               nb_microbatches=1,
               name='DPAdam'):
    """Construct a new DP Adam optimizer.

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate to
        use.
      beta1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor. The exponential decay
        rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      l2_norm_clip: Clipping parameter for DP-SGD.
      noise_multiplier: Noise multiplier for DP-SGD.
      nb_microbatches: Number of microbatches in which to split the input.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "DPAdam". @compatibility(eager) When eager
        execution is enabled, `learning_rate` can be a callable that takes no
        arguments and returns the actual value to use. This can be useful for
        changing these values across different invocations of optimizer
        functions. @end_compatibility
    """
    super(DPAdamOptimizer, self).__init__(
        learning_rate,
        beta1,
        beta2,
        epsilon,
        use_locking,
        name)
    stddev = l2_norm_clip * noise_multiplier
    self._nb_microbatches = nb_microbatches
    self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
                                                   nb_microbatches)
    self._ph_global_state = self._privacy_helper.initial_global_state()

  def compute_gradients(self,
                        loss,
                        var_list,
                        gate_gradients=tf.train.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):

    # Note: it would be closer to the correct i.i.d. sampling of records if
    # we sampled each microbatch from the appropriate binomial distribution,
    # although that still wouldn't be quite correct because it would be
    # sampling from the dataset without replacement.
    microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1])
    sample_params = (
        self._privacy_helper.derive_sample_params(self._ph_global_state))

    def process_microbatch(i, sample_state):
      """Process one microbatch (record) with privacy helper."""
      grads, _ = zip(*super(DPAdamOptimizer, self).compute_gradients(
          tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
          aggregation_method, colocate_gradients_with_ops, grad_loss))
      grads_list = list(grads)
      sample_state = self._privacy_helper.accumulate_record(
          sample_params, sample_state, grads_list)
      return [tf.add(i, 1), sample_state]

    i = tf.constant(0)

    if var_list is None:
      var_list = (
          tf.trainable_variables() +
          tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    sample_state = self._privacy_helper.initial_sample_state(
        self._ph_global_state, var_list)

    # Use of while_loop here requires that sample_state be a nested structure
    # of tensors. In general, we would prefer to allow it to be an arbitrary
    # opaque type.
    _, final_state = tf.while_loop(
        lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch,
        [i, sample_state])
    final_grads, self._ph_global_state = (
        self._privacy_helper.get_noised_average(final_state,
                                                self._ph_global_state))

    return zip(final_grads, var_list)
privacy/optimizers/dp_gradient_descent.py (deleted, 107 lines)
@@ -1,107 +0,0 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DPGradientDescentOptimizer for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import privacy.optimizers.gaussian_average_query as ph


class DPGradientDescentOptimizer(tf.train.GradientDescentOptimizer):
  """Optimizer that implements the DP gradient descent algorithm."""

  def __init__(self,
               learning_rate,
               use_locking=False,
               l2_norm_clip=1e9,
               noise_multiplier=0.0,
               nb_microbatches=1,
               name='DPGradientDescent'):
    """Construct a new DP gradient descent optimizer.

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate to
        use.
      use_locking: If True use locks for update operations.
      l2_norm_clip: Clipping parameter for DP-SGD.
      noise_multiplier: Noise multiplier for DP-SGD.
      nb_microbatches: Number of microbatches in which to split the input.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "DPGradientDescent". @compatibility(eager) When
        eager execution is enabled, `learning_rate` can be a callable that
        takes no arguments and returns the actual value to use. This can be
        useful for changing these values across different invocations of
        optimizer functions. @end_compatibility
    """
    super(DPGradientDescentOptimizer, self).__init__(learning_rate,
                                                     use_locking, name)
    stddev = l2_norm_clip * noise_multiplier
    self._nb_microbatches = nb_microbatches
    self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
                                                   nb_microbatches)
    self._ph_global_state = self._privacy_helper.initial_global_state()

  def compute_gradients(self,
                        loss,
                        var_list,
                        gate_gradients=tf.train.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):

    # Note: it would be closer to the correct i.i.d. sampling of records if
    # we sampled each microbatch from the appropriate binomial distribution,
    # although that still wouldn't be quite correct because it would be
    # sampling from the dataset without replacement.
    microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1])
    sample_params = (
        self._privacy_helper.derive_sample_params(self._ph_global_state))

    def process_microbatch(i, sample_state):
      """Process one microbatch (record) with privacy helper."""
      grads, _ = zip(*super(DPGradientDescentOptimizer, self).compute_gradients(
          tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
          aggregation_method, colocate_gradients_with_ops, grad_loss))
      grads_list = list(grads)
      sample_state = self._privacy_helper.accumulate_record(
          sample_params, sample_state, grads_list)
      return [tf.add(i, 1), sample_state]

    i = tf.constant(0)

    if var_list is None:
      var_list = (
          tf.trainable_variables() +
          tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    sample_state = self._privacy_helper.initial_sample_state(
        self._ph_global_state, var_list)

    # Use of while_loop here requires that sample_state be a nested structure
    # of tensors. In general, we would prefer to allow it to be an arbitrary
    # opaque type.
    _, final_state = tf.while_loop(
        lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch,
        [i, sample_state])
    final_grads, self._ph_global_state = (
        self._privacy_helper.get_noised_average(final_state,
                                                self._ph_global_state))

    return zip(final_grads, var_list)
privacy/optimizers/dp_optimizer.py (new file, 100 lines)
@@ -0,0 +1,100 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Differentially private optimizers for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import privacy.optimizers.gaussian_average_query as ph


def make_optimizer_class(cls):
  """Constructs a DP optimizer class from an existing one."""
  if (tf.train.Optimizer.compute_gradients.__code__ is
      not cls.compute_gradients.__code__):
    tf.logging.warning(
        'WARNING: Calling make_optimizer_class() on class %s that overrides '
        'method compute_gradients(). Check to ensure that '
        'make_optimizer_class() does not interfere with overridden version.',
        cls.__name__)

  class DPOptimizerClass(cls):
    """Differentially private subclass of given class cls."""

    def __init__(self, l2_norm_clip, noise_multiplier, num_microbatches, *args,
                 **kwargs):
      super(DPOptimizerClass, self).__init__(*args, **kwargs)
      stddev = l2_norm_clip * noise_multiplier
      self._num_microbatches = num_microbatches
      self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
                                                     num_microbatches)
      self._ph_global_state = self._privacy_helper.initial_global_state()

    def compute_gradients(self,
                          loss,
                          var_list,
                          gate_gradients=tf.train.Optimizer.GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None):

      # Note: it would be closer to the correct i.i.d. sampling of records if
      # we sampled each microbatch from the appropriate binomial distribution,
      # although that still wouldn't be quite correct because it would be
      # sampling from the dataset without replacement.
      microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
      sample_params = (
          self._privacy_helper.derive_sample_params(self._ph_global_state))

      def process_microbatch(i, sample_state):
        """Process one microbatch (record) with privacy helper."""
        grads, _ = zip(*super(cls, self).compute_gradients(
            tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
            aggregation_method, colocate_gradients_with_ops, grad_loss))
        grads_list = list(grads)
        sample_state = self._privacy_helper.accumulate_record(
            sample_params, sample_state, grads_list)
        return [tf.add(i, 1), sample_state]

      i = tf.constant(0)

      if var_list is None:
        var_list = (
            tf.trainable_variables() + tf.get_collection(
                tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
      sample_state = self._privacy_helper.initial_sample_state(
          self._ph_global_state, var_list)

      # Use of while_loop here requires that sample_state be a nested structure
      # of tensors. In general, we would prefer to allow it to be an arbitrary
      # opaque type.
      _, final_state = tf.while_loop(
          lambda i, _: tf.less(i, self._num_microbatches), process_microbatch,
          [i, sample_state])
      final_grads, self._ph_global_state = (
          self._privacy_helper.get_noised_average(final_state,
                                                  self._ph_global_state))

      return zip(final_grads, var_list)

  return DPOptimizerClass


DPAdagradOptimizer = make_optimizer_class(tf.train.AdagradOptimizer)
DPAdamOptimizer = make_optimizer_class(tf.train.AdamOptimizer)
DPGradientDescentOptimizer = make_optimizer_class(
    tf.train.GradientDescentOptimizer)
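A minimal usage sketch for this factory (the same pattern appears in `tutorials/mnist_dpsgd_tutorial.py` later in this commit; the hyperparameter values here are illustrative placeholders):

```python
import tensorflow as tf

from privacy.optimizers import dp_optimizer

# Wrap a standard tf.train.Optimizer subclass. The DP-specific arguments come
# first; everything else is forwarded to the wrapped optimizer unchanged.
DPAdam = dp_optimizer.make_optimizer_class(tf.train.AdamOptimizer)
opt = DPAdam(
    l2_norm_clip=1.0,       # clip each microbatch gradient to this L2 norm
    noise_multiplier=1.1,   # noise stddev = l2_norm_clip * noise_multiplier
    num_microbatches=256,   # must evenly divide the batch size
    learning_rate=0.001)    # forwarded to tf.train.AdamOptimizer
```

The predefined `DPAdagradOptimizer`, `DPAdamOptimizer`, and `DPGradientDescentOptimizer` above are exactly this wrapping applied to the corresponding stock optimizers.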
privacy/optimizers/dp_optimizer_test.py (modified)
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """Tests for differentially private optimizers."""
 
 from __future__ import absolute_import
@@ -19,11 +18,11 @@ from __future__ import division
 from __future__ import print_function
 
 from absl.testing import parameterized
+import mock
 import numpy as np
 import tensorflow as tf
 
-from privacy.optimizers import dp_adam
-from privacy.optimizers import dp_gradient_descent
+from privacy.optimizers import dp_optimizer
 
 
 def loss(val0, val1):
@@ -33,22 +32,31 @@ def loss(val0, val1):
 
 class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
 
-  # Parameters for testing: optimizer, nb_microbatches, expected answer.
+  # Parameters for testing: optimizer, num_microbatches, expected answer.
   @parameterized.named_parameters(
-      ('DPGradientDescent 1', dp_gradient_descent.DPGradientDescentOptimizer, 1,
-       [-10.0, -10.0]),
-      ('DPGradientDescent 2', dp_gradient_descent.DPGradientDescentOptimizer, 2,
-       [-5.0, -5.0]),
-      ('DPGradientDescent 4', dp_gradient_descent.DPGradientDescentOptimizer, 4,
-       [-2.5, -2.5]), ('DPAdam 1', dp_adam.DPAdamOptimizer, 1, [-10.0, -10.0]),
-      ('DPAdam 2', dp_adam.DPAdamOptimizer, 2, [-5.0, -5.0]),
-      ('DPAdam 4', dp_adam.DPAdamOptimizer, 4, [-2.5, -2.5]))
-  def testBaseline(self, cls, nb_microbatches, expected_answer):
+      ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
+       [-10.0, -10.0]),
+      ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
+       [-5.0, -5.0]),
+      ('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4,
+       [-2.5, -2.5]),
+      ('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-10.0, -10.0]),
+      ('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-5.0, -5.0]),
+      ('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]),
+      ('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-10.0, -10.0]),
+      ('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-5.0, -5.0]),
+      ('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]))
+  def testBaseline(self, cls, num_microbatches, expected_answer):
     with self.cached_session() as sess:
       var0 = tf.Variable([1.0, 2.0])
       data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
 
-      opt = cls(learning_rate=2.0, nb_microbatches=nb_microbatches)
+      opt = cls(
+          l2_norm_clip=1.0e9,
+          noise_multiplier=0.0,
+          num_microbatches=num_microbatches,
+          learning_rate=2.0)
 
       self.evaluate(tf.global_variables_initializer())
       # Fetch params to validate initial values
       self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@@ -60,14 +68,20 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])
 
   @parameterized.named_parameters(
-      ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
-      ('DPAdam', dp_adam.DPAdamOptimizer))
+      ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
+      ('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
+      ('DPAdam', dp_optimizer.DPAdamOptimizer))
   def testClippingNorm(self, cls):
     with self.cached_session() as sess:
       var0 = tf.Variable([0.0, 0.0])
       data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
 
-      opt = cls(learning_rate=2.0, l2_norm_clip=1.0, nb_microbatches=1)
+      opt = cls(
+          l2_norm_clip=1.0,
+          noise_multiplier=0.0,
+          num_microbatches=1,
+          learning_rate=2.0)
 
       self.evaluate(tf.global_variables_initializer())
       # Fetch params to validate initial values
       self.assertAllClose([0.0, 0.0], self.evaluate(var0))
@@ -78,18 +92,20 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
 
   @parameterized.named_parameters(
-      ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
-      ('DPAdam', dp_adam.DPAdamOptimizer))
+      ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
+      ('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
+      ('DPAdam', dp_optimizer.DPAdamOptimizer))
   def testNoiseMultiplier(self, cls):
     with self.cached_session() as sess:
       var0 = tf.Variable([0.0])
       data0 = tf.Variable([[0.0]])
 
       opt = cls(
-          learning_rate=2.0,
           l2_norm_clip=4.0,
           noise_multiplier=2.0,
-          nb_microbatches=1)
+          num_microbatches=1,
+          learning_rate=2.0)
 
       self.evaluate(tf.global_variables_initializer())
       # Fetch params to validate initial values
       self.assertAllClose([0.0], self.evaluate(var0))
@@ -103,6 +119,20 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       # Test standard deviation is close to l2_norm_clip * noise_multiplier.
       self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
 
+  @mock.patch.object(tf, 'logging')
+  def testComputeGradientsOverrideWarning(self, mock_logging):
+
+    class SimpleOptimizer(tf.train.Optimizer):
+
+      def compute_gradients(self):
+        return 0
+
+    dp_optimizer.make_optimizer_class(SimpleOptimizer)
+    mock_logging.warning.assert_called_once_with(
+        'WARNING: Calling make_optimizer_class() on class %s that overrides '
+        'method compute_gradients(). Check to ensure that '
+        'make_optimizer_class() does not interfere with overridden version.',
+        'SimpleOptimizer')
+
 if __name__ == '__main__':
   tf.test.main()
tutorials/README.md (new file, 51 lines)
@@ -0,0 +1,51 @@
# Tutorials

As demonstrated on MNIST in `mnist_dpsgd_tutorial.py`, the easiest way to use
a differentially private optimizer is to modify an existing training loop
to replace an existing vanilla optimizer with its differentially private
counterpart implemented in the library.

## Parameters

All of the optimizers share some privacy-specific parameters that need to
be tuned in addition to any existing hyperparameter. There are currently three
(a construction sketch follows the list):

*   num_microbatches (int): The input data for each step (i.e., batch) of your
    original training algorithm is split into this many microbatches.
    Generally, increasing this will improve your utility but slow down your
    training in terms of wall-clock time. The total number of examples
    consumed in one global step remains the same. This number should evenly
    divide your input batch size.
*   l2_norm_clip (float): The cumulative gradient across all network
    parameters from each microbatch will be clipped so that its L2 norm is at
    most this value. You should set this to something close to some percentile
    of what you expect the gradient from each microbatch to be. In previous
    experiments, we've found numbers from 0.5 to 1.0 to work reasonably well.
*   noise_multiplier (float): This governs the amount of noise added during
    training. Generally, more noise results in better privacy and lower
    utility. This generally has to be at least 0.3 to obtain rigorous privacy
    guarantees, but smaller values may still be acceptable for practical
    purposes.
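To make the three knobs concrete, here is a minimal sketch of constructing one of the library's optimizers with them (the values are illustrative placeholders, not tuned recommendations beyond the ranges discussed above):

```python
import tensorflow as tf

from privacy.optimizers import dp_optimizer

# DP-specific arguments come first; the usual optimizer hyperparameters
# (here just the learning rate) follow and are forwarded unchanged.
optimizer = dp_optimizer.DPGradientDescentOptimizer(
    l2_norm_clip=1.0,       # clip each microbatch gradient to L2 norm <= 1.0
    noise_multiplier=1.1,   # noise stddev = l2_norm_clip * noise_multiplier
    num_microbatches=256,   # must evenly divide the input batch size
    learning_rate=0.15)     # illustrative value for plain SGD
```

The resulting optimizer can then be dropped into an existing training loop in place of `tf.train.GradientDescentOptimizer`.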
## Measuring Privacy

Differential privacy is measured by two values, epsilon and delta. Roughly
speaking, they mean the following:

*   epsilon gives a ceiling on how much the probability of a change in model
    behavior can increase by including a single extra training example. This
    is the far more sensitive value, and we usually want it to be at most 10.0
    or so. However, note that this is only an upper bound, and a large value
    of epsilon may still mean good practical privacy.
*   delta bounds the probability of an "unconditional" change in model
    behavior. We can usually set this to a very small number (1e-7 or so)
    without compromising utility. A rule of thumb is to set it to the inverse
    of the order of magnitude of the training data size.

To find out the epsilon given a fixed delta value for your model, follow the
approach demonstrated in the `compute_epsilon` function of
`mnist_dpsgd_tutorial.py`, where the arguments used to call the RDP accountant
(i.e., the tool used to compute the privacy guarantee) are the following
(combined in the sketch below):

*   q : The sampling ratio, defined as (number of examples consumed in one
    step) / (total training examples).
*   stddev_to_sensitivity_ratio : The noise_multiplier from your parameters
    above.
*   steps : The number of global steps taken.
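Combining the three, a minimal sketch along the lines of the tutorial's `compute_epsilon` helper (60000 is the MNIST training set size used by the tutorial, and 1e-5 is the matching delta per the rule of thumb above):

```python
from __future__ import division

from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent


def compute_epsilon(steps, batch_size=256, noise_multiplier=1.0, delta=1e-5):
  """Returns the epsilon spent after `steps` global steps, at a fixed delta."""
  orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
  sampling_probability = batch_size / 60000  # q for MNIST
  rdp = compute_rdp(q=sampling_probability,
                    stddev_to_sensitivity_ratio=noise_multiplier,
                    steps=steps,
                    orders=orders)
  # get_privacy_spent returns (eps, delta, optimal_order); keep eps.
  return get_privacy_spent(orders, rdp, target_delta=delta)[0]
```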
tutorials/mnist_dpsgd_tutorial.py (new file, 183 lines)
@@ -0,0 +1,183 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training a CNN on MNIST with differentially private Adam optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer

tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-Adam. If '
                        'False, train with vanilla Adam.')
tf.flags.DEFINE_float('learning_rate', 0.0015, 'Learning rate for training')
tf.flags.DEFINE_float('noise_multiplier', 1.0,
                      'Ratio of the standard deviation to the clipping norm')
tf.flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
tf.flags.DEFINE_integer('batch_size', 256, 'Batch size')
tf.flags.DEFINE_integer('epochs', 15, 'Number of epochs')
tf.flags.DEFINE_integer('microbatches', 256,
                        'Number of microbatches (must evenly divide '
                        'batch_size)')
tf.flags.DEFINE_string('model_dir', None, 'Model directory')

FLAGS = tf.flags.FLAGS


def cnn_model_fn(features, labels, mode):
  """Model function for a CNN."""

  # Define CNN architecture using tf.keras.layers.
  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
  y = tf.keras.layers.Conv2D(16, 8,
                             strides=2,
                             padding='same',
                             kernel_initializer='he_normal').apply(input_layer)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Conv2D(32, 4,
                             strides=2,
                             padding='valid',
                             kernel_initializer='he_normal').apply(y)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Flatten().apply(y)
  y = tf.keras.layers.Dense(32, kernel_initializer='he_normal').apply(y)
  logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y)

  # Calculate loss as a vector (to support microbatches in DP-SGD).
  vector_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=labels, logits=logits)
  # Define mean of loss across minibatch (for reporting through tf.Estimator).
  scalar_loss = tf.reduce_mean(vector_loss)

  # Configure the training op (for TRAIN mode).
  if mode == tf.estimator.ModeKeys.TRAIN:

    if FLAGS.dpsgd:
      # Use DP version of AdamOptimizer. For illustration purposes, we do that
      # here by calling make_optimizer_class() explicitly, though DP versions
      # of standard optimizers are available in dp_optimizer.
      dp_optimizer_class = dp_optimizer.make_optimizer_class(
          tf.train.AdamOptimizer)
      optimizer = dp_optimizer_class(
          learning_rate=FLAGS.learning_rate,
          noise_multiplier=FLAGS.noise_multiplier,
          l2_norm_clip=FLAGS.l2_norm_clip,
          num_microbatches=FLAGS.microbatches)
    else:
      optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=scalar_loss,
                                      train_op=train_op)

  # Add evaluation metrics (for EVAL mode).
  elif mode == tf.estimator.ModeKeys.EVAL:
    eval_metric_ops = {
        'accuracy':
            tf.metrics.accuracy(
                labels=tf.argmax(labels, axis=1),
                predictions=tf.argmax(input=logits, axis=1))
    }
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=scalar_loss,
                                      eval_metric_ops=eval_metric_ops)


def load_mnist():
  """Loads MNIST and preprocesses to combine training and validation data."""
  train, test = tf.keras.datasets.mnist.load_data()
  train_data, train_labels = train
  test_data, test_labels = test

  train_data = np.array(train_data, dtype=np.float32) / 255
  test_data = np.array(test_data, dtype=np.float32) / 255

  train_labels = tf.keras.utils.to_categorical(train_labels)
  test_labels = tf.keras.utils.to_categorical(test_labels)

  assert train_data.min() == 0.
  assert train_data.max() == 1.
  assert test_data.min() == 0.
  assert test_data.max() == 1.
  assert train_labels.shape[1] == 10
  assert test_labels.shape[1] == 10

  return train_data, train_labels, test_data, test_labels


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.batch_size % FLAGS.microbatches != 0:
    raise ValueError('Number of microbatches should evenly divide batch_size')

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()

  # Instantiate the tf.Estimator.
  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                            model_dir=FLAGS.model_dir)

  # Create tf.Estimator input functions for the training and test data.
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': train_data},
      y=train_labels,
      batch_size=FLAGS.batch_size,
      num_epochs=FLAGS.epochs,
      shuffle=True)
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': test_data},
      y=test_labels,
      num_epochs=1,
      shuffle=False)

  # Define a function that computes privacy budget expended so far.
  def compute_epsilon(steps):
    """Computes epsilon value for given hyperparameters."""
    if FLAGS.noise_multiplier == 0.0:
      return float('inf')
    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
    sampling_probability = FLAGS.batch_size / 60000
    rdp = compute_rdp(q=sampling_probability,
                      stddev_to_sensitivity_ratio=FLAGS.noise_multiplier,
                      steps=steps,
                      orders=orders)
    # Delta is set to 1e-5 because MNIST has 60000 training points.
    return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]

  # Training loop.
  steps_per_epoch = 60000 // FLAGS.batch_size
  for epoch in range(1, FLAGS.epochs + 1):
    # Train the model for one epoch.
    mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch)

    # Evaluate the model and print results
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    test_accuracy = eval_results['accuracy']
    print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))

    # Compute the privacy budget expended so far.
    if FLAGS.dpsgd:
      eps = compute_epsilon(epoch * steps_per_epoch)
      print('For delta=1e-5, the current epsilon is: %.2f' % eps)
    else:
      print('Trained with vanilla non-private Adam optimizer')


if __name__ == '__main__':
  tf.app.run()
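To run the tutorial (a sketch; it assumes the repository root is on `PYTHONPATH`, and the flag values shown are the defaults defined above, so they may be omitted):

```
python tutorials/mnist_dpsgd_tutorial.py --dpsgd=True --learning_rate=0.0015 \
    --noise_multiplier=1.0 --l2_norm_clip=1.0 --batch_size=256 --epochs=15 \
    --microbatches=256
```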