Project import generated by Copybara.

PiperOrigin-RevId: 226056146
This commit is contained in:
Steve Chien 2018-12-18 14:06:54 -08:00 committed by Nicolas Papernot
parent ceee90b1ac
commit 1595ed3cd1
12 changed files with 918 additions and 254 deletions

View file

@ -1,8 +1,88 @@
# TensorFlow Privacy
This repository will contain implementations of TensorFlow optimizers that
support training machine learning models with (differential) privacy, as well
as tutorials and analysis tools for computing the privacy guarantees provided.
This repository contains the source code for TensorFlow Privacy, a Python
library that includes implementations of TensorFlow optimizers for training
machine learning models with differential privacy. The library comes with
tutorials and analysis tools for computing the privacy guarantees provided.
The content of this repository will supersede the following existing repository:
https://github.com/tensorflow/models/tree/master/research/differential_privacy
The TensorFlow Privacy library is under continual development, always welcoming
contributions. In particular, we always welcome help towards resolving the
issues currently open.
## Setting up TensorFlow Privacy
### Dependencies
This library uses [TensorFlow](https://www.tensorflow.org/) to define machine
learning models. Therefore, installing TensorFlow is a pre-requisite. You can
find instructions [here](https://www.tensorflow.org/install/). For better
performance, it is also recommended to install TensorFlow with GPU support
(detailed instructions on how to do this are available in the TensorFlow
installation documentation).
Installing TensorFlow will take care of all other dependencies like `numpy` and
`scipy`.
### Installing TensorFlow Privacy
First, clone this GitHub repository into a directory of your choice:
```
git clone https://github.com/tensorflow/privacy
```
You can then install the local package in "editable" mode in order to add it to
your `PYTHONPATH`:
```
cd privacy
pip install -e ./privacy
```
If you'd like to make contributions, we recommend first forking the repository
and then cloning your fork rather than cloning this repository directly.
## Contributing
Contributions are welcomed! Bug fixes and new features can be initiated through
Github pull requests. To speed the code review process, we ask that:
* When making code contributions to TensorFlow Privacy, you follow the `PEP8
with two spaces` coding style (the same as the one used by TensorFlow) in
your pull requests. In most cases this can be done by running `autopep8 -i
--indent-size 2 <file>` on the files you have edited.
* When making your first pull request, you
[sign the Google CLA](https://cla.developers.google.com/clas)
* We do not accept pull requests that add git submodules because of
[the problems that arise when maintaining git submodules](https://medium.com/@porteneuve/mastering-git-submodules-34c65e940407)
## Tutorials directory
To help you get started with the functionalities provided by this library, the
`tutorials/` folder comes with scripts demonstrating how to use the library
features.
NOTE: the tutorials are maintained carefully. However, they are not considered
part of the API and they can change at any time without warning. You should not
write 3rd party code that imports the tutorials and expect that the interface
will not break.
## Remarks
The content of this repository supersedes the following existing folder in the
tensorflow/models [repository](https://github.com/tensorflow/models/tree/master/research/differential_privacy)
## Contacts
If you have any questions that cannot be addressed by raising an issue, feel
free to contact:
* Nicolas Papernot (@npapernot)
* Steve Chien
* Galen Andrew (@galenmandrew)
## Copyright
Copyright 2018 - Google LLC

0
privacy/__init__.py Normal file
View file

View file

View file

@ -0,0 +1,295 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RDP analysis of the Sampled Gaussian mechanism.
Functionality for computing Renyi differential privacy (RDP) of an additive
Sampled Gaussian mechanism (SGM). Its public interface consists of two methods:
compute_rdp(q, stddev_to_sensitivity_ratio, T, orders) computes RDP with for
SGM iterated T times.
get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
(or eps) given RDP at multiple orders and
a target value for eps (or delta).
Example use:
Suppose that we have run an SGM applied to a function with l2-sensitivity 1.
Its parameters are given as a list of tuples (q1, sigma1, T1), ...,
(qk, sigma_k, Tk), and we wish to compute eps for a given delta.
The example code would be:
max_order = 32
orders = range(2, max_order + 1)
rdp = np.zeros_like(orders, dtype=float)
for q, sigma, T in parameters:
rdp += rdp_accountant.compute_rdp(q, sigma, T, orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(rdp, target_delta=delta)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
import numpy as np
from scipy import special
########################
# LOG-SPACE ARITHMETIC #
########################
def _log_add(logx, logy):
"""Add two numbers in the log space."""
a, b = min(logx, logy), max(logx, logy)
if a == -np.inf: # adding 0
return b
# Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1)
def _log_sub(logx, logy):
"""Subtract two numbers in the log space. Answer must be non-negative."""
if logx < logy:
raise ValueError("The result of subtraction must be non-negative .")
if logy == -np.inf: # subtracting 0
return logx
if logx == logy:
return -np.inf # 0 is represented as -np.inf in the log space.
try:
# Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1
except OverflowError:
return logx
def _log_print(logx):
"""Pretty print."""
if logx < math.log(sys.float_info.max):
return "{}".format(math.exp(logx))
else:
return "exp({})".format(logx)
def _compute_log_a_int(q, sigma, alpha):
"""Compute log(A_alpha) for integer alpha. 0 < q < 1."""
assert isinstance(alpha, (int, long))
# Initialize with 0 in the log space.
log_a = -np.inf
for i in range(alpha + 1):
log_coef_i = (
math.log(special.binom(alpha, i)) + i * math.log(q) +
(alpha - i) * math.log(1 - q))
s = log_coef_i + (i * i - i) / (2 * (sigma**2))
log_a = _log_add(log_a, s)
return float(log_a)
def _compute_log_a_frac(q, sigma, alpha):
"""Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
# The two parts of A_alpha, integrals over (-inf,z0] and (z0, +inf), are
# initialized to 0 in the log space:
log_a0, log_a1 = -np.inf, -np.inf
i = 0
z0 = sigma**2 * math.log(1 / q - 1) + .5
while True: # do ... until loop
coef = special.binom(alpha, i)
log_coef = math.log(abs(coef))
j = alpha - i
log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1
if coef > 0:
log_a0 = _log_add(log_a0, log_s0)
log_a1 = _log_add(log_a1, log_s1)
else:
log_a0 = _log_sub(log_a0, log_s0)
log_a1 = _log_sub(log_a1, log_s1)
i += 1
if max(log_s0, log_s1) < -30:
break
return _log_add(log_a0, log_a1)
def _compute_log_a(q, sigma, alpha):
"""Compute log(A_alpha) for any positive finite alpha."""
if float(alpha).is_integer():
return _compute_log_a_int(q, sigma, int(alpha))
else:
return _compute_log_a_frac(q, sigma, alpha)
def _log_erfc(x):
try:
return math.log(2) + special.log_ndtr(-x * 2**.5)
except NameError:
# If log_ndtr is not available, approximate as follows:
r = special.erfc(x)
if r == 0.0:
# Using the Laurent series at infinity for the tail of the erfc function:
# erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
# To verify in Mathematica:
# Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
.625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
else:
return math.log(r)
def _compute_delta(orders, rdp, eps):
"""Compute delta given an RDP curve and target epsilon.
Args:
orders: An array (or a scalar) of orders.
rdp: A list (or a scalar) of RDP guarantees.
eps: The target epsilon.
Returns:
Pair of (delta, optimal_order).
Raises:
ValueError: If input is malformed.
"""
orders_vec = np.atleast_1d(orders)
rdp_vec = np.atleast_1d(rdp)
if len(orders_vec) != len(rdp_vec):
raise ValueError("Input lists must have the same length.")
deltas = np.exp((rdp_vec - eps) * (orders_vec - 1))
idx_opt = np.argmin(deltas)
return min(deltas[idx_opt], 1.), orders_vec[idx_opt]
def _compute_eps(orders, rdp, delta):
"""Compute epsilon given an RDP curve and target delta.
Args:
orders: An array (or a scalar) of orders.
rdp: A list (or a scalar) of RDP guarantees.
delta: The target delta.
Returns:
Pair of (eps, optimal_order).
Raises:
ValueError: If input is malformed.
"""
orders_vec = np.atleast_1d(orders)
rdp_vec = np.atleast_1d(rdp)
if len(orders_vec) != len(rdp_vec):
raise ValueError("Input lists must have the same length.")
eps = rdp_vec - math.log(delta) / (orders_vec - 1)
idx_opt = np.nanargmin(eps) # Ignore NaNs
return eps[idx_opt], orders_vec[idx_opt]
def _compute_rdp(q, sigma, alpha):
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
Args:
q: The sampling rate.
sigma: The std of the additive Gaussian noise.
alpha: The order at which RDP is computed.
Returns:
RDP at alpha, can be np.inf.
"""
if q == 0:
return 0
if q == 1.:
return alpha / (2 * sigma**2)
if np.isinf(alpha):
return np.inf
return _compute_log_a(q, sigma, alpha) / (alpha - 1)
def compute_rdp(q, stddev_to_sensitivity_ratio, steps, orders):
"""Compute RDP of the Sampled Gaussian Mechanism for given parameters.
Args:
q: The sampling rate.
stddev_to_sensitivity_ratio: The ratio of std of the Gaussian noise to the
l2-sensitivity of the function to which it is added.
steps: The number of steps.
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders, can be np.inf.
"""
if np.isscalar(orders):
rdp = _compute_rdp(q, stddev_to_sensitivity_ratio, orders)
else:
rdp = np.array([_compute_rdp(q, stddev_to_sensitivity_ratio, order)
for order in orders])
return rdp * steps
def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
"""Compute delta (or eps) for given eps (or delta) from the RDP curve.
Args:
orders: An array (or a scalar) of RDP orders.
rdp: An array of RDP values. Must be of the same length as the orders list.
target_eps: If not None, the epsilon for which we compute the corresponding
delta.
target_delta: If not None, the delta for which we compute the corresponding
epsilon. Exactly one of target_eps and target_delta must be None.
Returns:
eps, delta, opt_order.
Raises:
ValueError: If target_eps and target_delta are messed up.
"""
if target_eps is None and target_delta is None:
raise ValueError(
"Exactly one out of eps and delta must be None. (Both are).")
if target_eps is not None and target_delta is not None:
raise ValueError(
"Exactly one out of eps and delta must be None. (None is).")
if target_eps is not None:
delta, opt_order = _compute_delta(orders, rdp, target_eps)
return target_eps, delta, opt_order
else:
eps, opt_order = _compute_eps(orders, rdp, target_delta)
return eps, target_delta, opt_order

View file

@ -0,0 +1,155 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rdp_accountant.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from absl.testing import absltest
from absl.testing import parameterized
import mpmath as mp
import numpy as np
import rdp_accountant
class TestGaussianMoments(parameterized.TestCase):
#################################
# HELPER FUNCTIONS: #
# Exact computations using #
# multi-precision arithmetic. #
#################################
def _log_float_mp(self, x):
# Convert multi-precision input to float log space.
if x >= sys.float_info.min:
return float(mp.log(x))
else:
return -np.inf
def _integral_mp(self, fn, bounds=(-mp.inf, mp.inf)):
integral, _ = mp.quad(fn, bounds, error=True, maxdegree=8)
return integral
def _distributions_mp(self, sigma, q):
def _mu0(x):
return mp.npdf(x, mu=0, sigma=sigma)
def _mu1(x):
return mp.npdf(x, mu=1, sigma=sigma)
def _mu(x):
return (1 - q) * _mu0(x) + q * _mu1(x)
return _mu0, _mu # Closure!
def _mu1_over_mu0(self, x, sigma):
# Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
return mp.exp((2 * x - 1) / (2 * sigma**2))
def _mu_over_mu0(self, x, q, sigma):
return (1 - q) + q * self._mu1_over_mu0(x, sigma)
def _compute_a_mp(self, sigma, q, alpha):
"""Compute A_alpha for arbitrary alpha by numerical integration."""
mu0, _ = self._distributions_mp(sigma, q)
a_alpha_fn = lambda z: mu0(z) * self._mu_over_mu0(z, q, sigma)**alpha
a_alpha = self._integral_mp(a_alpha_fn)
return a_alpha
# TEST ROUTINES
def test_compute_rdp_no_data(self):
# q = 0
self.assertEqual(rdp_accountant.compute_rdp(0, 10, 1, 20), 0)
def test_compute_rdp_no_sampling(self):
# q = 1, RDP = alpha/2 * sigma^2
self.assertEqual(rdp_accountant.compute_rdp(1, 10, 1, 20), 0.1)
def test_compute_rdp_scalar(self):
rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)
def test_compute_rdp_sequence(self):
rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
[1.5, 2.5, 5, 50, 100, np.inf])
self.assertSequenceAlmostEqual(
rdp_vec, [0.00065, 0.001085, 0.00218075, 0.023846, 167.416307, np.inf],
delta=1e-5)
params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
{'q': 1e-6, 'sigma': .1, 'order': 256},
{'q': 1e-5, 'sigma': .1, 'order': 256.1},
{'q': 1e-6, 'sigma': 1, 'order': 27},
{'q': 1e-4, 'sigma': 1., 'order': 1.5},
{'q': 1e-3, 'sigma': 1., 'order': 2},
{'q': .01, 'sigma': 10, 'order': 20},
{'q': .1, 'sigma': 100, 'order': 20.5},
{'q': .99, 'sigma': .1, 'order': 256},
{'q': .999, 'sigma': 100, 'order': 256.1})
# pylint:disable=undefined-variable
@parameterized.parameters(p for p in params)
def test_compute_log_a_equals_mp(self, q, sigma, order):
# Compare the cheap computation of log(A) with an expensive, multi-precision
# computation.
log_a = rdp_accountant._compute_log_a(q, sigma, order)
log_a_mp = self._log_float_mp(self._compute_a_mp(sigma, q, order))
np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4)
def test_get_privacy_spent_check_target_delta(self):
orders = range(2, 33)
rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=1e-5)
self.assertAlmostEqual(eps, 1.258575, places=5)
self.assertEqual(opt_order, 20)
def test_get_privacy_spent_check_target_eps(self):
orders = range(2, 33)
rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
_, delta, opt_order = rdp_accountant.get_privacy_spent(
orders, rdp, target_eps=1.258575)
self.assertAlmostEqual(delta, 1e-5)
self.assertEqual(opt_order, 20)
def test_check_composition(self):
orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14.,
16., 20., 24., 28., 32., 64., 256.)
rdp = rdp_accountant.compute_rdp(q=1e-4,
stddev_to_sensitivity_ratio=.4,
steps=40000,
orders=orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
target_delta=1e-6)
rdp += rdp_accountant.compute_rdp(q=0.1,
stddev_to_sensitivity_ratio=2,
steps=100,
orders=orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
target_delta=1e-5)
self.assertAlmostEqual(eps, 8.509656, places=5)
self.assertEqual(opt_order, 2.5)
if __name__ == '__main__':
absltest.main()

View file

View file

@ -1,123 +0,0 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DPAdamOptimizer for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import privacy.optimizers.gaussian_average_query as ph
class DPAdamOptimizer(tf.train.AdamOptimizer):
"""Optimizer that implements the DP Adam algorithm.
"""
def __init__(self,
learning_rate,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
use_locking=False,
l2_norm_clip=1e9,
noise_multiplier=0.0,
nb_microbatches=1,
name='DPAdam'):
"""Construct a new DP Adam optimizer.
Args:
learning_rate: A Tensor or a floating point value. The learning rate to
use.
beta1: A float value or a constant float tensor.
The exponential decay rate for the 1st moment estimates.
beta2: A float value or a constant float tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just before
Section 2.1), not the epsilon in Algorithm 1 of the paper.
use_locking: If True use locks for update operations.
l2_norm_clip: Clipping parameter for DP-SGD.
noise_multiplier: Noise multiplier for DP-SGD.
nb_microbatches: Number of microbatches in which to split the input.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "DPAdam". @compatibility(eager) When eager
execution is enabled, `learning_rate` can be a callable that takes no
arguments and returns the actual value to use. This can be useful for
changing these values across different invocations of optimizer
functions. @end_compatibility
"""
super(DPAdamOptimizer, self).__init__(
learning_rate,
beta1,
beta2,
epsilon,
use_locking,
name)
stddev = l2_norm_clip * noise_multiplier
self._nb_microbatches = nb_microbatches
self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
nb_microbatches)
self._ph_global_state = self._privacy_helper.initial_global_state()
def compute_gradients(self,
loss,
var_list,
gate_gradients=tf.train.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None):
# Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be sampling
# from the dataset without replacement.
microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1])
sample_params = (
self._privacy_helper.derive_sample_params(self._ph_global_state))
def process_microbatch(i, sample_state):
"""Process one microbatch (record) with privacy helper."""
grads, _ = zip(*super(DPAdamOptimizer, self).compute_gradients(
tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
aggregation_method, colocate_gradients_with_ops, grad_loss))
grads_list = list(grads)
sample_state = self._privacy_helper.accumulate_record(
sample_params, sample_state, grads_list)
return [tf.add(i, 1), sample_state]
i = tf.constant(0)
if var_list is None:
var_list = (
tf.trainable_variables() +
tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._privacy_helper.initial_sample_state(
self._ph_global_state, var_list)
# Use of while_loop here requires that sample_state be a nested structure of
# tensors. In general, we would prefer to allow it to be an arbitrary
# opaque type.
_, final_state = tf.while_loop(
lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch,
[i, sample_state])
final_grads, self._ph_global_state = (
self._privacy_helper.get_noised_average(final_state,
self._ph_global_state))
return zip(final_grads, var_list)

View file

@ -1,107 +0,0 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DPGradientDescentOptimizer for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import privacy.optimizers.gaussian_average_query as ph
class DPGradientDescentOptimizer(tf.train.GradientDescentOptimizer):
"""Optimizer that implements the DP gradient descent algorithm.
"""
def __init__(self,
learning_rate,
use_locking=False,
l2_norm_clip=1e9,
noise_multiplier=0.0,
nb_microbatches=1,
name='DPGradientDescent'):
"""Construct a new DP gradient descent optimizer.
Args:
learning_rate: A Tensor or a floating point value. The learning rate to
use.
use_locking: If True use locks for update operations.
l2_norm_clip: Clipping parameter for DP-SGD.
noise_multiplier: Noise multiplier for DP-SGD.
nb_microbatches: Number of microbatches in which to split the input.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "DPGradientDescent". @compatibility(eager) When
eager execution is enabled, `learning_rate` can be a callable that takes
no arguments and returns the actual value to use. This can be useful for
changing these values across different invocations of optimizer
functions. @end_compatibility
"""
super(DPGradientDescentOptimizer, self).__init__(learning_rate, use_locking,
name)
stddev = l2_norm_clip * noise_multiplier
self._nb_microbatches = nb_microbatches
self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
nb_microbatches)
self._ph_global_state = self._privacy_helper.initial_global_state()
def compute_gradients(self,
loss,
var_list,
gate_gradients=tf.train.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None):
# Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be sampling
# from the dataset without replacement.
microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1])
sample_params = (
self._privacy_helper.derive_sample_params(self._ph_global_state))
def process_microbatch(i, sample_state):
"""Process one microbatch (record) with privacy helper."""
grads, _ = zip(*super(DPGradientDescentOptimizer, self).compute_gradients(
tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
aggregation_method, colocate_gradients_with_ops, grad_loss))
grads_list = list(grads)
sample_state = self._privacy_helper.accumulate_record(
sample_params, sample_state, grads_list)
return [tf.add(i, 1), sample_state]
i = tf.constant(0)
if var_list is None:
var_list = (
tf.trainable_variables() +
tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._privacy_helper.initial_sample_state(
self._ph_global_state, var_list)
# Use of while_loop here requires that sample_state be a nested structure of
# tensors. In general, we would prefer to allow it to be an arbitrary
# opaque type.
_, final_state = tf.while_loop(
lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch,
[i, sample_state])
final_grads, self._ph_global_state = (
self._privacy_helper.get_noised_average(final_state,
self._ph_global_state))
return zip(final_grads, var_list)

View file

@ -0,0 +1,100 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Differentially private optimizers for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import privacy.optimizers.gaussian_average_query as ph
def make_optimizer_class(cls):
"""Constructs a DP optimizer class from an existing one."""
if (tf.train.Optimizer.compute_gradients.__code__ is
not cls.compute_gradients.__code__):
tf.logging.warning(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
cls.__name__)
class DPOptimizerClass(cls):
"""Differentially private subclass of given class cls."""
def __init__(self, l2_norm_clip, noise_multiplier, num_microbatches, *args,
**kwargs):
super(DPOptimizerClass, self).__init__(*args, **kwargs)
stddev = l2_norm_clip * noise_multiplier
self._num_microbatches = num_microbatches
self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
num_microbatches)
self._ph_global_state = self._privacy_helper.initial_global_state()
def compute_gradients(self,
loss,
var_list,
gate_gradients=tf.train.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None):
# Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be
# sampling from the dataset without replacement.
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
sample_params = (
self._privacy_helper.derive_sample_params(self._ph_global_state))
def process_microbatch(i, sample_state):
"""Process one microbatch (record) with privacy helper."""
grads, _ = zip(*super(cls, self).compute_gradients(
tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
aggregation_method, colocate_gradients_with_ops, grad_loss))
grads_list = list(grads)
sample_state = self._privacy_helper.accumulate_record(
sample_params, sample_state, grads_list)
return [tf.add(i, 1), sample_state]
i = tf.constant(0)
if var_list is None:
var_list = (
tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._privacy_helper.initial_sample_state(
self._ph_global_state, var_list)
# Use of while_loop here requires that sample_state be a nested structure
# of tensors. In general, we would prefer to allow it to be an arbitrary
# opaque type.
_, final_state = tf.while_loop(
lambda i, _: tf.less(i, self._num_microbatches), process_microbatch,
[i, sample_state])
final_grads, self._ph_global_state = (
self._privacy_helper.get_noised_average(final_state,
self._ph_global_state))
return zip(final_grads, var_list)
return DPOptimizerClass
DPAdagradOptimizer = make_optimizer_class(tf.train.AdagradOptimizer)
DPAdamOptimizer = make_optimizer_class(tf.train.AdamOptimizer)
DPGradientDescentOptimizer = make_optimizer_class(
tf.train.GradientDescentOptimizer)

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for differentially private optimizers."""
from __future__ import absolute_import
@ -19,11 +18,11 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import mock
import numpy as np
import tensorflow as tf
from privacy.optimizers import dp_adam
from privacy.optimizers import dp_gradient_descent
from privacy.optimizers import dp_optimizer
def loss(val0, val1):
@ -33,22 +32,31 @@ def loss(val0, val1):
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
# Parameters for testing: optimizer, nb_microbatches, expected answer.
# Parameters for testing: optimizer, num_microbatches, expected answer.
@parameterized.named_parameters(
('DPGradientDescent 1', dp_gradient_descent.DPGradientDescentOptimizer, 1,
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
[-10.0, -10.0]),
('DPGradientDescent 2', dp_gradient_descent.DPGradientDescentOptimizer, 2,
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
[-5.0, -5.0]),
('DPGradientDescent 4', dp_gradient_descent.DPGradientDescentOptimizer, 4,
[-2.5, -2.5]), ('DPAdam 1', dp_adam.DPAdamOptimizer, 1, [-10.0, -10.0]),
('DPAdam 2', dp_adam.DPAdamOptimizer, 2, [-5.0, -5.0]),
('DPAdam 4', dp_adam.DPAdamOptimizer, 4, [-2.5, -2.5]))
def testBaseline(self, cls, nb_microbatches, expected_answer):
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4,
[-2.5, -2.5]),
('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-10.0, -10.0]),
('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-5.0, -5.0]),
('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]),
('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-10.0, -10.0]),
('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-5.0, -5.0]),
('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]))
def testBaseline(self, cls, num_microbatches, expected_answer):
with self.cached_session() as sess:
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
opt = cls(learning_rate=2.0, nb_microbatches=nb_microbatches)
opt = cls(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@ -60,14 +68,20 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
('DPAdam', dp_adam.DPAdamOptimizer))
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer))
def testClippingNorm(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
opt = cls(learning_rate=2.0, l2_norm_clip=1.0, nb_microbatches=1)
opt = cls(
l2_norm_clip=1.0,
noise_multiplier=0.0,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
@ -78,18 +92,20 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
('DPAdam', dp_adam.DPAdamOptimizer))
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer))
def testNoiseMultiplier(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
opt = cls(
learning_rate=2.0,
l2_norm_clip=4.0,
noise_multiplier=2.0,
nb_microbatches=1)
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
@ -103,6 +119,20 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
@mock.patch.object(tf, 'logging')
def testComputeGradientsOverrideWarning(self, mock_logging):
class SimpleOptimizer(tf.train.Optimizer):
def compute_gradients(self):
return 0
dp_optimizer.make_optimizer_class(SimpleOptimizer)
mock_logging.warning.assert_called_once_with(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
'SimpleOptimizer')
if __name__ == '__main__':
tf.test.main()

51
tutorials/README.md Normal file
View file

@ -0,0 +1,51 @@
# Tutorials
As demonstrated on MNIST in `mnist_dpsgd_tutorial.py`, the easiest to use
a differentially private optimizer is to modify an existing training loop
to replace an existing vanilla optimizer with its differentially private
counterpart implemented in the library.
## Parameters
All of the optimizers share some privacy-specific parameters that need to
be tuned in addition to any existing hyperparameter. There are currently three:
* num_microbatches (int): The input data for each step (i.e., batch) of your
original training algorithm is split into this many microbatches. Generally,
increasing this will improve your utility but slow down your training in terms
of wall-clock time. The total number of examples consumed in one global step
remains the same. This number should evenly divide your input batch size.
* l2_norm_clip (float): The cumulative gradient across all network parameters
from each microbatch will be clipped so that its L2 norm is at most this
value. You should set this to something close to some percentile of what
you expect the gradient from each microbatch to be. In previous experiments,
we've found numbers from 0.5 to 1.0 to work reasonably well.
* noise_multiplier (float): This governs the amount of noise added during
training. Generally, more noise results in better privacy and lower utility.
This generally has to be at least 0.3 to obtain rigorous privacy guarantees,
but smaller values may still be acceptable for practical purposes.
## Measuring Privacy
Differential privacy is measured by two values, epsilon and delta. Roughly
speaking, they mean the following:
* epsilon gives a ceiling on how much the probability of a change in model
behavior can increase by including a single extra training example. This is
the far more sensitive value, and we usually want it to be at most 10.0 or
so. However, note that this is only an upper bound, and a large value of
epsilon may still mean good practical privacy.
* delta bounds the probability of an "unconditional" change in model behavior.
We can usually set this to a very small number (1e-7 or so) without
compromising utility. A rule of thumb is to set it to the inverse of the
order of magnitude of the training data size.
To find out the epsilon given a fixed delta value for your model, follow the
approach demonstrated in the `compute_epsilon` of the `mnist_dpsgd_tutorial.py`
where the arguments used to call the RDP accountant (i.e., the tool used to
compute the privacy guarantee) are:
* q : The sampling ratio, defined as (number of examples consumed in one
step) / (total training examples).
* stddev_to_sensitivity_ratio : The noise_multiplier from your parameters above.
* steps : The number of global steps taken.

View file

@ -0,0 +1,183 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training a CNN on MNIST with differentially private Adam optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer
tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-Adam. If False,'
'train with vanilla Adam.')
tf.flags.DEFINE_float('learning_rate', 0.0015, 'Learning rate for training')
tf.flags.DEFINE_float('noise_multiplier', 1.0,
'Ratio of the standard deviation to the clipping norm')
tf.flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
tf.flags.DEFINE_integer('batch_size', 256, 'Batch size')
tf.flags.DEFINE_integer('epochs', 15, 'Number of epochs')
tf.flags.DEFINE_integer('microbatches', 256,
'Number of microbatches (must evenly divide batch_size')
tf.flags.DEFINE_string('model_dir', None, 'Model directory')
FLAGS = tf.flags.FLAGS
def cnn_model_fn(features, labels, mode):
"""Model function for a CNN."""
# Define CNN architecture using tf.keras.layers.
input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
y = tf.keras.layers.Conv2D(16, 8,
strides=2,
padding='same',
kernel_initializer='he_normal').apply(input_layer)
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
y = tf.keras.layers.Conv2D(32, 4,
strides=2,
padding='valid',
kernel_initializer='he_normal').apply(y)
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
y = tf.keras.layers.Flatten().apply(y)
y = tf.keras.layers.Dense(32, kernel_initializer='he_normal').apply(y)
logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y)
# Calculate loss as a vector (to support microbatches in DP-SGD).
vector_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=labels, logits=logits)
# Define mean of loss across minibatch (for reporting through tf.Estimator).
scalar_loss = tf.reduce_mean(vector_loss)
# Configure the training op (for TRAIN mode).
if mode == tf.estimator.ModeKeys.TRAIN:
if FLAGS.dpsgd:
# Use DP version of AdamOptimizer. For illustration purposes, we do that
# here by calling make_optimizer_class() explicitly, though DP versions
# of standard optimizers are available in dp_optimizer.
dp_optimizer_class = dp_optimizer.make_optimizer_class(
tf.train.AdamOptimizer)
optimizer = dp_optimizer_class(
learning_rate=FLAGS.learning_rate,
noise_multiplier=FLAGS.noise_multiplier,
l2_norm_clip=FLAGS.l2_norm_clip,
num_microbatches=FLAGS.microbatches)
else:
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
train_op=train_op)
# Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = {
'accuracy':
tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1),
predictions=tf.argmax(input=logits, axis=1))
}
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
eval_metric_ops=eval_metric_ops)
def load_mnist():
"""Loads MNIST and preprocesses to combine training and validation data."""
train, test = tf.keras.datasets.mnist.load_data()
train_data, train_labels = train
test_data, test_labels = test
train_data = np.array(train_data, dtype=np.float32) / 255
test_data = np.array(test_data, dtype=np.float32) / 255
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)
assert train_data.min() == 0.
assert train_data.max() == 1.
assert test_data.min() == 0.
assert test_data.max() == 1.
assert train_labels.shape[1] == 10
assert test_labels.shape[1] == 10
return train_data, train_labels, test_data, test_labels
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.batch_size % FLAGS.microbatches != 0:
raise ValueError('Number of microbatches should divide evenly batch_size')
# Load training and test data.
train_data, train_labels, test_data, test_labels = load_mnist()
# Instantiate the tf.Estimator.
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
model_dir=FLAGS.model_dir)
# Create tf.Estimator input functions for the training and test data.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': train_data},
y=train_labels,
batch_size=FLAGS.batch_size,
num_epochs=FLAGS.epochs,
shuffle=True)
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': test_data},
y=test_labels,
num_epochs=1,
shuffle=False)
# Define a function that computes privacy budget expended so far.
def compute_epsilon(steps):
"""Computes epsilon value for given hyperparameters."""
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + range(12, 64)
sampling_probability = FLAGS.batch_size / 60000
rdp = compute_rdp(q=sampling_probability,
stddev_to_sensitivity_ratio=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to 1e-5 because MNIST has 60000 training points.
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
# Training loop.
steps_per_epoch = 60000 // FLAGS.batch_size
for epoch in range(1, FLAGS.epochs + 1):
# Train the model for one epoch.
mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch)
# Evaluate the model and print results
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
test_accuracy = eval_results['accuracy']
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
# Compute the privacy budget expended so far.
if FLAGS.dpsgd:
eps = compute_epsilon(epoch * steps_per_epoch)
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
else:
print('Trained with vanilla non-private Adam optimizer')
if __name__ == '__main__':
tf.app.run()