diff --git a/README.md b/README.md
index 1e799d7..7cdf98a 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,88 @@
 # TensorFlow Privacy
 
-This repository will contain implementations of TensorFlow optimizers that
-support training machine learning models with (differential) privacy, as well
-as tutorials and analysis tools for computing the privacy guarantees provided.
+This repository contains the source code for TensorFlow Privacy, a Python
+library that includes implementations of TensorFlow optimizers for training
+machine learning models with differential privacy. The library comes with
+tutorials and analysis tools for computing the privacy guarantees provided.
 
-The content of this repository will supersede the following existing repository:
- https://github.com/tensorflow/models/tree/master/research/differential_privacy
+The TensorFlow Privacy library is under continual development and contributions
+are welcome. In particular, we welcome help towards resolving the issues
+currently open.
+
+## Setting up TensorFlow Privacy
+
+### Dependencies
+
+This library uses [TensorFlow](https://www.tensorflow.org/) to define machine
+learning models. Therefore, installing TensorFlow is a prerequisite. You can
+find instructions [here](https://www.tensorflow.org/install/). For better
+performance, it is also recommended to install TensorFlow with GPU support
+(detailed instructions on how to do this are available in the TensorFlow
+installation documentation).
+
+Installing TensorFlow will take care of all other dependencies, such as `numpy`
+and `scipy`.
+
+### Installing TensorFlow Privacy
+
+First, clone this GitHub repository into a directory of your choice:
+
+```
+git clone https://github.com/tensorflow/privacy
+```
+
+You can then install the local package in "editable" mode in order to add it to
+your `PYTHONPATH`:
+
+```
+cd privacy
+pip install -e .
+```
+
+If you'd like to make contributions, we recommend first forking the repository
+and then cloning your fork rather than cloning this repository directly.
+
+## Contributing
+
+Contributions are welcome! Bug fixes and new features can be initiated through
+GitHub pull requests. To speed up the code review process, we ask that:
+
+* When making code contributions to TensorFlow Privacy, you follow the `PEP8
+  with two spaces` coding style (the same as the one used by TensorFlow) in
+  your pull requests. In most cases this can be done by running
+  `autopep8 -i --indent-size 2 <file>` on the files you have edited.
+
+* When making your first pull request, you
+  [sign the Google CLA](https://cla.developers.google.com/clas).
+
+* We do not accept pull requests that add git submodules because of
+  [the problems that arise when maintaining git submodules](https://medium.com/@porteneuve/mastering-git-submodules-34c65e940407).
+
+## Tutorials directory
+
+To help you get started with the functionality provided by this library, the
+`tutorials/` folder comes with scripts demonstrating how to use the library
+features.
+
+NOTE: the tutorials are maintained carefully. However, they are not considered
+part of the API and they can change at any time without warning. You should not
+write third-party code that imports the tutorials and expect that the interface
+will not break.
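+
+For orientation, the following is a minimal sketch of the intended usage,
+mirroring `dp_optimizer_test.py` (the toy model, loss, and all parameter
+values below are illustrative only):
+
+```
+import tensorflow as tf
+
+from privacy.optimizers import dp_optimizer
+
+# Toy model: one weight vector fitted to four training examples.
+var0 = tf.Variable([1.0, 2.0])
+data0 = tf.constant([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
+
+# DP optimizers consume a *vector* loss with one entry per example.
+vector_loss = 0.5 * tf.reduce_sum(tf.squared_difference(var0, data0), axis=1)
+
+opt = dp_optimizer.DPGradientDescentOptimizer(
+    l2_norm_clip=1.0,      # clip each microbatch gradient to this L2 norm
+    noise_multiplier=1.1,  # noise stddev is l2_norm_clip * noise_multiplier
+    num_microbatches=4,    # must evenly divide the number of loss entries
+    learning_rate=0.1)
+train_op = opt.minimize(vector_loss)
+
+with tf.Session() as sess:
+  sess.run(tf.global_variables_initializer())
+  sess.run(train_op)
+```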
+
+## Remarks
+
+The content of this repository supersedes the following existing folder in the
+tensorflow/models
+[repository](https://github.com/tensorflow/models/tree/master/research/differential_privacy).
+
+## Contacts
+
+If you have any questions that cannot be addressed by raising an issue, feel
+free to contact:
+
+* Nicolas Papernot (@npapernot)
+* Steve Chien
+* Galen Andrew (@galenmandrew)
+
+## Copyright
+
+Copyright 2018 - Google LLC
diff --git a/privacy/__init__.py b/privacy/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/privacy/analysis/__init__.py b/privacy/analysis/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/privacy/analysis/rdp_accountant.py b/privacy/analysis/rdp_accountant.py
new file mode 100644
index 0000000..4910385
--- /dev/null
+++ b/privacy/analysis/rdp_accountant.py
@@ -0,0 +1,295 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RDP analysis of the Sampled Gaussian mechanism.
+
+Functionality for computing Renyi differential privacy (RDP) of an additive
+Sampled Gaussian mechanism (SGM). Its public interface consists of two methods:
+  compute_rdp(q, stddev_to_sensitivity_ratio, T, orders) computes RDP of an
+    SGM iterated T times.
+  get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
+    (or eps) given RDP at multiple orders and a target value for eps (or
+    delta).
+
+Example use:
+
+Suppose that we have run an SGM applied to a function with l2-sensitivity 1.
+Its parameters are given as a list of tuples (q1, sigma1, T1), ...,
+(qk, sigma_k, Tk), and we wish to compute eps for a given delta.
+The example code would be:
+
+  max_order = 32
+  orders = range(2, max_order + 1)
+  rdp = np.zeros_like(orders, dtype=float)
+  for q, sigma, T in parameters:
+    rdp += rdp_accountant.compute_rdp(q, sigma, T, orders)
+  eps, _, opt_order = rdp_accountant.get_privacy_spent(
+      orders, rdp, target_delta=delta)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import sys
+
+import numpy as np
+from scipy import special
+
+########################
+# LOG-SPACE ARITHMETIC #
+########################
+
+
+def _log_add(logx, logy):
+  """Add two numbers in the log space."""
+  a, b = min(logx, logy), max(logx, logy)
+  if a == -np.inf:  # adding 0
+    return b
+  # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
+  return math.log1p(math.exp(a - b)) + b  # log1p(x) = log(x + 1)
+
+
+def _log_sub(logx, logy):
+  """Subtract two numbers in the log space. Answer must be non-negative."""
+  if logx < logy:
+    raise ValueError("The result of subtraction must be non-negative.")
+  if logy == -np.inf:  # subtracting 0
+    return logx
+  if logx == logy:
+    return -np.inf  # 0 is represented as -np.inf in the log space.
+
+  try:
+    # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
+    return math.log(math.expm1(logx - logy)) + logy  # expm1(x) = exp(x) - 1
+  except OverflowError:
+    return logx
+
+
+def _log_print(logx):
+  """Pretty print."""
+  if logx < math.log(sys.float_info.max):
+    return "{}".format(math.exp(logx))
+  else:
+    return "exp({})".format(logx)
+
+
+def _compute_log_a_int(q, sigma, alpha):
+  """Compute log(A_alpha) for integer alpha. 0 < q < 1."""
+  assert isinstance(alpha, int)
+
+  # Initialize with 0 in the log space.
+  log_a = -np.inf
+
+  for i in range(alpha + 1):
+    log_coef_i = (
+        math.log(special.binom(alpha, i)) + i * math.log(q) +
+        (alpha - i) * math.log(1 - q))
+
+    s = log_coef_i + (i * i - i) / (2 * (sigma**2))
+    log_a = _log_add(log_a, s)
+
+  return float(log_a)
+
+
+def _compute_log_a_frac(q, sigma, alpha):
+  """Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
+  # The two parts of A_alpha, integrals over (-inf,z0] and (z0, +inf), are
+  # initialized to 0 in the log space:
+  log_a0, log_a1 = -np.inf, -np.inf
+  i = 0
+
+  z0 = sigma**2 * math.log(1 / q - 1) + .5
+
+  while True:  # do ... until loop
+    coef = special.binom(alpha, i)
+    log_coef = math.log(abs(coef))
+    j = alpha - i
+
+    log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
+    log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
+
+    log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
+    log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
+
+    log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
+    log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1
+
+    if coef > 0:
+      log_a0 = _log_add(log_a0, log_s0)
+      log_a1 = _log_add(log_a1, log_s1)
+    else:
+      log_a0 = _log_sub(log_a0, log_s0)
+      log_a1 = _log_sub(log_a1, log_s1)
+
+    i += 1
+    if max(log_s0, log_s1) < -30:
+      break
+
+  return _log_add(log_a0, log_a1)
+
+
+def _compute_log_a(q, sigma, alpha):
+  """Compute log(A_alpha) for any positive finite alpha."""
+  if float(alpha).is_integer():
+    return _compute_log_a_int(q, sigma, int(alpha))
+  else:
+    return _compute_log_a_frac(q, sigma, alpha)
+
+
+def _log_erfc(x):
+  try:
+    return math.log(2) + special.log_ndtr(-x * 2**.5)
+  except AttributeError:
+    # If special.log_ndtr is not available, approximate as follows:
+    r = special.erfc(x)
+    if r == 0.0:
+      # Using the Laurent series at infinity for the tail of the erfc function:
+      # erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
+      # To verify in Mathematica:
+      # Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
+      return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
+              .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
+    else:
+      return math.log(r)
+
+
+def _compute_delta(orders, rdp, eps):
+  """Compute delta given an RDP curve and target epsilon.
+
+  Args:
+    orders: An array (or a scalar) of orders.
+    rdp: A list (or a scalar) of RDP guarantees.
+    eps: The target epsilon.
+
+  Returns:
+    Pair of (delta, optimal_order).
+
+  Raises:
+    ValueError: If input is malformed.
+  """
+  orders_vec = np.atleast_1d(orders)
+  rdp_vec = np.atleast_1d(rdp)
+
+  if len(orders_vec) != len(rdp_vec):
+    raise ValueError("Input lists must have the same length.")
+
+  deltas = np.exp((rdp_vec - eps) * (orders_vec - 1))
+  idx_opt = np.argmin(deltas)
+  return min(deltas[idx_opt], 1.), orders_vec[idx_opt]
+
+
+def _compute_eps(orders, rdp, delta):
+  """Compute epsilon given an RDP curve and target delta.
+
+  Args:
+    orders: An array (or a scalar) of orders.
+    rdp: A list (or a scalar) of RDP guarantees.
+    delta: The target delta.
+
+  Returns:
+    Pair of (eps, optimal_order).
+
+  Raises:
+    ValueError: If input is malformed.
+  """
+  orders_vec = np.atleast_1d(orders)
+  rdp_vec = np.atleast_1d(rdp)
+
+  if len(orders_vec) != len(rdp_vec):
+    raise ValueError("Input lists must have the same length.")
+
+  eps = rdp_vec - math.log(delta) / (orders_vec - 1)
+
+  idx_opt = np.nanargmin(eps)  # Ignore NaNs
+  return eps[idx_opt], orders_vec[idx_opt]
+
+
+def _compute_rdp(q, sigma, alpha):
+  """Compute RDP of the Sampled Gaussian mechanism at order alpha.
+
+  Args:
+    q: The sampling rate.
+    sigma: The std of the additive Gaussian noise.
+    alpha: The order at which RDP is computed.
+
+  Returns:
+    RDP at alpha, can be np.inf.
+  """
+  if q == 0:
+    return 0
+
+  if q == 1.:
+    return alpha / (2 * sigma**2)
+
+  if np.isinf(alpha):
+    return np.inf
+
+  return _compute_log_a(q, sigma, alpha) / (alpha - 1)
+
+
+def compute_rdp(q, stddev_to_sensitivity_ratio, steps, orders):
+  """Compute RDP of the Sampled Gaussian Mechanism for given parameters.
+
+  Args:
+    q: The sampling rate.
+    stddev_to_sensitivity_ratio: The ratio of the std of the Gaussian noise to
+      the l2-sensitivity of the function to which it is added.
+    steps: The number of steps.
+    orders: An array (or a scalar) of RDP orders.
+
+  Returns:
+    The RDPs at all orders, can be np.inf.
+  """
+
+  if np.isscalar(orders):
+    rdp = _compute_rdp(q, stddev_to_sensitivity_ratio, orders)
+  else:
+    rdp = np.array([_compute_rdp(q, stddev_to_sensitivity_ratio, order)
+                    for order in orders])
+
+  return rdp * steps
+
+
+def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
+  """Compute delta (or eps) for given eps (or delta) from the RDP curve.
+
+  Args:
+    orders: An array (or a scalar) of RDP orders.
+    rdp: An array of RDP values. Must be of the same length as the orders list.
+    target_eps: If not None, the epsilon for which we compute the corresponding
+      delta.
+    target_delta: If not None, the delta for which we compute the corresponding
+      epsilon. Exactly one of target_eps and target_delta must be None.
+
+  Returns:
+    eps, delta, opt_order.
+
+  Raises:
+    ValueError: If neither or both of target_eps and target_delta are None.
+  """
+  if target_eps is None and target_delta is None:
+    raise ValueError(
+        "Exactly one out of eps and delta must be None. (Both are).")
+
+  if target_eps is not None and target_delta is not None:
+    raise ValueError(
+        "Exactly one out of eps and delta must be None. (None is).")
+
+  if target_eps is not None:
+    delta, opt_order = _compute_delta(orders, rdp, target_eps)
+    return target_eps, delta, opt_order
+  else:
+    eps, opt_order = _compute_eps(orders, rdp, target_delta)
+    return eps, target_delta, opt_order
diff --git a/privacy/analysis/rdp_accountant_test.py b/privacy/analysis/rdp_accountant_test.py
new file mode 100644
index 0000000..603f7ad
--- /dev/null
+++ b/privacy/analysis/rdp_accountant_test.py
@@ -0,0 +1,155 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rdp_accountant.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import mpmath as mp
+import numpy as np
+
+from privacy.analysis import rdp_accountant
+
+
+class TestGaussianMoments(parameterized.TestCase):
+  #################################
+  # HELPER FUNCTIONS:             #
+  # Exact computations using      #
+  # multi-precision arithmetic.   #
+  #################################
+
+  def _log_float_mp(self, x):
+    # Convert multi-precision input to float log space.
+    if x >= sys.float_info.min:
+      return float(mp.log(x))
+    else:
+      return -np.inf
+
+  def _integral_mp(self, fn, bounds=(-mp.inf, mp.inf)):
+    integral, _ = mp.quad(fn, bounds, error=True, maxdegree=8)
+    return integral
+
+  def _distributions_mp(self, sigma, q):
+
+    def _mu0(x):
+      return mp.npdf(x, mu=0, sigma=sigma)
+
+    def _mu1(x):
+      return mp.npdf(x, mu=1, sigma=sigma)
+
+    def _mu(x):
+      return (1 - q) * _mu0(x) + q * _mu1(x)
+
+    return _mu0, _mu  # Closure!
+
+  def _mu1_over_mu0(self, x, sigma):
+    # Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
+    return mp.exp((2 * x - 1) / (2 * sigma**2))
+
+  def _mu_over_mu0(self, x, q, sigma):
+    return (1 - q) + q * self._mu1_over_mu0(x, sigma)
+
+  def _compute_a_mp(self, sigma, q, alpha):
+    """Compute A_alpha for arbitrary alpha by numerical integration."""
+    mu0, _ = self._distributions_mp(sigma, q)
+    a_alpha_fn = lambda z: mu0(z) * self._mu_over_mu0(z, q, sigma)**alpha
+    a_alpha = self._integral_mp(a_alpha_fn)
+    return a_alpha
+
+  # TEST ROUTINES
+  def test_compute_rdp_no_data(self):
+    # q = 0
+    self.assertEqual(rdp_accountant.compute_rdp(0, 10, 1, 20), 0)
+
+  def test_compute_rdp_no_sampling(self):
+    # q = 1, RDP = alpha / (2 * sigma^2)
+    self.assertEqual(rdp_accountant.compute_rdp(1, 10, 1, 20), 0.1)
+
+  def test_compute_rdp_scalar(self):
+    rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
+    self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)
+
+  def test_compute_rdp_sequence(self):
+    rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
+                                         [1.5, 2.5, 5, 50, 100, np.inf])
+    self.assertSequenceAlmostEqual(
+        rdp_vec, [0.00065, 0.001085, 0.00218075, 0.023846, 167.416307, np.inf],
+        delta=1e-5)
+
+  params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
+            {'q': 1e-6, 'sigma': .1, 'order': 256},
+            {'q': 1e-5, 'sigma': .1, 'order': 256.1},
+            {'q': 1e-6, 'sigma': 1, 'order': 27},
+            {'q': 1e-4, 'sigma': 1., 'order': 1.5},
+            {'q': 1e-3, 'sigma': 1., 'order': 2},
+            {'q': .01, 'sigma': 10, 'order': 20},
+            {'q': .1, 'sigma': 100, 'order': 20.5},
+            {'q': .99, 'sigma': .1, 'order': 256},
+            {'q': .999, 'sigma': 100, 'order': 256.1})
+
+  # pylint:disable=undefined-variable
+  @parameterized.parameters(p for p in params)
+  def test_compute_log_a_equals_mp(self, q, sigma, order):
+    # Compare the cheap computation of log(A) with an expensive, multi-precision
+    # computation.
+ log_a = rdp_accountant._compute_log_a(q, sigma, order) + log_a_mp = self._log_float_mp(self._compute_a_mp(sigma, q, order)) + np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4) + + def test_get_privacy_spent_check_target_delta(self): + orders = range(2, 33) + rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders) + eps, _, opt_order = rdp_accountant.get_privacy_spent( + orders, rdp, target_delta=1e-5) + self.assertAlmostEqual(eps, 1.258575, places=5) + self.assertEqual(opt_order, 20) + + def test_get_privacy_spent_check_target_eps(self): + orders = range(2, 33) + rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders) + _, delta, opt_order = rdp_accountant.get_privacy_spent( + orders, rdp, target_eps=1.258575) + self.assertAlmostEqual(delta, 1e-5) + self.assertEqual(opt_order, 20) + + def test_check_composition(self): + orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14., + 16., 20., 24., 28., 32., 64., 256.) + + rdp = rdp_accountant.compute_rdp(q=1e-4, + stddev_to_sensitivity_ratio=.4, + steps=40000, + orders=orders) + + eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, + target_delta=1e-6) + + rdp += rdp_accountant.compute_rdp(q=0.1, + stddev_to_sensitivity_ratio=2, + steps=100, + orders=orders) + eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, + target_delta=1e-5) + self.assertAlmostEqual(eps, 8.509656, places=5) + self.assertEqual(opt_order, 2.5) + + +if __name__ == '__main__': + absltest.main() diff --git a/privacy/optimizers/__init__.py b/privacy/optimizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/privacy/optimizers/dp_adam.py b/privacy/optimizers/dp_adam.py deleted file mode 100644 index 1578782..0000000 --- a/privacy/optimizers/dp_adam.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2018, The TensorFlow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DPAdamOptimizer for TensorFlow.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -import privacy.optimizers.gaussian_average_query as ph - - -class DPAdamOptimizer(tf.train.AdamOptimizer): - """Optimizer that implements the DP Adam algorithm. - - """ - - def __init__(self, - learning_rate, - beta1=0.9, - beta2=0.999, - epsilon=1e-8, - use_locking=False, - l2_norm_clip=1e9, - noise_multiplier=0.0, - nb_microbatches=1, - name='DPAdam'): - """Construct a new DP Adam optimizer. - - Args: - learning_rate: A Tensor or a floating point value. The learning rate to - use. - beta1: A float value or a constant float tensor. - The exponential decay rate for the 1st moment estimates. - beta2: A float value or a constant float tensor. - The exponential decay rate for the 2nd moment estimates. - epsilon: A small constant for numerical stability. This epsilon is - "epsilon hat" in the Kingma and Ba paper (in the formula just before - Section 2.1), not the epsilon in Algorithm 1 of the paper. - use_locking: If True use locks for update operations. 
- l2_norm_clip: Clipping parameter for DP-SGD. - noise_multiplier: Noise multiplier for DP-SGD. - nb_microbatches: Number of microbatches in which to split the input. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "DPAdam". @compatibility(eager) When eager - execution is enabled, `learning_rate` can be a callable that takes no - arguments and returns the actual value to use. This can be useful for - changing these values across different invocations of optimizer - functions. @end_compatibility - """ - super(DPAdamOptimizer, self).__init__( - learning_rate, - beta1, - beta2, - epsilon, - use_locking, - name) - stddev = l2_norm_clip * noise_multiplier - self._nb_microbatches = nb_microbatches - self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev, - nb_microbatches) - self._ph_global_state = self._privacy_helper.initial_global_state() - - def compute_gradients(self, - loss, - var_list, - gate_gradients=tf.train.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - grad_loss=None): - - # Note: it would be closer to the correct i.i.d. sampling of records if - # we sampled each microbatch from the appropriate binomial distribution, - # although that still wouldn't be quite correct because it would be sampling - # from the dataset without replacement. - microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1]) - sample_params = ( - self._privacy_helper.derive_sample_params(self._ph_global_state)) - - def process_microbatch(i, sample_state): - """Process one microbatch (record) with privacy helper.""" - grads, _ = zip(*super(DPAdamOptimizer, self).compute_gradients( - tf.gather(microbatches_losses, [i]), var_list, gate_gradients, - aggregation_method, colocate_gradients_with_ops, grad_loss)) - grads_list = list(grads) - sample_state = self._privacy_helper.accumulate_record( - sample_params, sample_state, grads_list) - return [tf.add(i, 1), sample_state] - - i = tf.constant(0) - - if var_list is None: - var_list = ( - tf.trainable_variables() + - tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) - sample_state = self._privacy_helper.initial_sample_state( - self._ph_global_state, var_list) - - # Use of while_loop here requires that sample_state be a nested structure of - # tensors. In general, we would prefer to allow it to be an arbitrary - # opaque type. - _, final_state = tf.while_loop( - lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch, - [i, sample_state]) - final_grads, self._ph_global_state = ( - self._privacy_helper.get_noised_average(final_state, - self._ph_global_state)) - - return zip(final_grads, var_list) - diff --git a/privacy/optimizers/dp_gradient_descent.py b/privacy/optimizers/dp_gradient_descent.py deleted file mode 100644 index 141e18b..0000000 --- a/privacy/optimizers/dp_gradient_descent.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2018, The TensorFlow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""DPGradientDescentOptimizer for TensorFlow.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -import privacy.optimizers.gaussian_average_query as ph - - -class DPGradientDescentOptimizer(tf.train.GradientDescentOptimizer): - """Optimizer that implements the DP gradient descent algorithm. - - """ - - def __init__(self, - learning_rate, - use_locking=False, - l2_norm_clip=1e9, - noise_multiplier=0.0, - nb_microbatches=1, - name='DPGradientDescent'): - """Construct a new DP gradient descent optimizer. - - Args: - learning_rate: A Tensor or a floating point value. The learning rate to - use. - use_locking: If True use locks for update operations. - l2_norm_clip: Clipping parameter for DP-SGD. - noise_multiplier: Noise multiplier for DP-SGD. - nb_microbatches: Number of microbatches in which to split the input. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "DPGradientDescent". @compatibility(eager) When - eager execution is enabled, `learning_rate` can be a callable that takes - no arguments and returns the actual value to use. This can be useful for - changing these values across different invocations of optimizer - functions. @end_compatibility - """ - super(DPGradientDescentOptimizer, self).__init__(learning_rate, use_locking, - name) - stddev = l2_norm_clip * noise_multiplier - self._nb_microbatches = nb_microbatches - self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev, - nb_microbatches) - self._ph_global_state = self._privacy_helper.initial_global_state() - - def compute_gradients(self, - loss, - var_list, - gate_gradients=tf.train.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - grad_loss=None): - - # Note: it would be closer to the correct i.i.d. sampling of records if - # we sampled each microbatch from the appropriate binomial distribution, - # although that still wouldn't be quite correct because it would be sampling - # from the dataset without replacement. - microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1]) - sample_params = ( - self._privacy_helper.derive_sample_params(self._ph_global_state)) - - def process_microbatch(i, sample_state): - """Process one microbatch (record) with privacy helper.""" - grads, _ = zip(*super(DPGradientDescentOptimizer, self).compute_gradients( - tf.gather(microbatches_losses, [i]), var_list, gate_gradients, - aggregation_method, colocate_gradients_with_ops, grad_loss)) - grads_list = list(grads) - sample_state = self._privacy_helper.accumulate_record( - sample_params, sample_state, grads_list) - return [tf.add(i, 1), sample_state] - - i = tf.constant(0) - - if var_list is None: - var_list = ( - tf.trainable_variables() + - tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) - sample_state = self._privacy_helper.initial_sample_state( - self._ph_global_state, var_list) - - # Use of while_loop here requires that sample_state be a nested structure of - # tensors. In general, we would prefer to allow it to be an arbitrary - # opaque type. 
- _, final_state = tf.while_loop( - lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch, - [i, sample_state]) - final_grads, self._ph_global_state = ( - self._privacy_helper.get_noised_average(final_state, - self._ph_global_state)) - - return zip(final_grads, var_list) diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py new file mode 100644 index 0000000..f0f323b --- /dev/null +++ b/privacy/optimizers/dp_optimizer.py @@ -0,0 +1,100 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Differentially private optimizers for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +import privacy.optimizers.gaussian_average_query as ph + + +def make_optimizer_class(cls): + """Constructs a DP optimizer class from an existing one.""" + if (tf.train.Optimizer.compute_gradients.__code__ is + not cls.compute_gradients.__code__): + tf.logging.warning( + 'WARNING: Calling make_optimizer_class() on class %s that overrides ' + 'method compute_gradients(). Check to ensure that ' + 'make_optimizer_class() does not interfere with overridden version.', + cls.__name__) + + class DPOptimizerClass(cls): + """Differentially private subclass of given class cls.""" + + def __init__(self, l2_norm_clip, noise_multiplier, num_microbatches, *args, + **kwargs): + super(DPOptimizerClass, self).__init__(*args, **kwargs) + stddev = l2_norm_clip * noise_multiplier + self._num_microbatches = num_microbatches + self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev, + num_microbatches) + self._ph_global_state = self._privacy_helper.initial_global_state() + + def compute_gradients(self, + loss, + var_list, + gate_gradients=tf.train.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + + # Note: it would be closer to the correct i.i.d. sampling of records if + # we sampled each microbatch from the appropriate binomial distribution, + # although that still wouldn't be quite correct because it would be + # sampling from the dataset without replacement. 
+ microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1]) + sample_params = ( + self._privacy_helper.derive_sample_params(self._ph_global_state)) + + def process_microbatch(i, sample_state): + """Process one microbatch (record) with privacy helper.""" + grads, _ = zip(*super(cls, self).compute_gradients( + tf.gather(microbatches_losses, [i]), var_list, gate_gradients, + aggregation_method, colocate_gradients_with_ops, grad_loss)) + grads_list = list(grads) + sample_state = self._privacy_helper.accumulate_record( + sample_params, sample_state, grads_list) + return [tf.add(i, 1), sample_state] + + i = tf.constant(0) + + if var_list is None: + var_list = ( + tf.trainable_variables() + tf.get_collection( + tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + sample_state = self._privacy_helper.initial_sample_state( + self._ph_global_state, var_list) + + # Use of while_loop here requires that sample_state be a nested structure + # of tensors. In general, we would prefer to allow it to be an arbitrary + # opaque type. + _, final_state = tf.while_loop( + lambda i, _: tf.less(i, self._num_microbatches), process_microbatch, + [i, sample_state]) + final_grads, self._ph_global_state = ( + self._privacy_helper.get_noised_average(final_state, + self._ph_global_state)) + + return zip(final_grads, var_list) + + return DPOptimizerClass + + +DPAdagradOptimizer = make_optimizer_class(tf.train.AdagradOptimizer) +DPAdamOptimizer = make_optimizer_class(tf.train.AdamOptimizer) +DPGradientDescentOptimizer = make_optimizer_class( + tf.train.GradientDescentOptimizer) diff --git a/privacy/optimizers/dp_optimizer_test.py b/privacy/optimizers/dp_optimizer_test.py index 290888d..a5f24d2 100644 --- a/privacy/optimizers/dp_optimizer_test.py +++ b/privacy/optimizers/dp_optimizer_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for differentially private optimizers.""" from __future__ import absolute_import @@ -19,11 +18,11 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized +import mock import numpy as np import tensorflow as tf -from privacy.optimizers import dp_adam -from privacy.optimizers import dp_gradient_descent +from privacy.optimizers import dp_optimizer def loss(val0, val1): @@ -33,22 +32,31 @@ def loss(val0, val1): class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): - # Parameters for testing: optimizer, nb_microbatches, expected answer. + # Parameters for testing: optimizer, num_microbatches, expected answer. 
   @parameterized.named_parameters(
-      ('DPGradientDescent 1', dp_gradient_descent.DPGradientDescentOptimizer, 1,
+      ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
        [-10.0, -10.0]),
-      ('DPGradientDescent 2', dp_gradient_descent.DPGradientDescentOptimizer, 2,
+      ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
        [-5.0, -5.0]),
-      ('DPGradientDescent 4', dp_gradient_descent.DPGradientDescentOptimizer, 4,
-       [-2.5, -2.5]), ('DPAdam 1', dp_adam.DPAdamOptimizer, 1, [-10.0, -10.0]),
-      ('DPAdam 2', dp_adam.DPAdamOptimizer, 2, [-5.0, -5.0]),
-      ('DPAdam 4', dp_adam.DPAdamOptimizer, 4, [-2.5, -2.5]))
-  def testBaseline(self, cls, nb_microbatches, expected_answer):
+      ('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4,
+       [-2.5, -2.5]),
+      ('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-10.0, -10.0]),
+      ('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-5.0, -5.0]),
+      ('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]),
+      ('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-10.0, -10.0]),
+      ('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-5.0, -5.0]),
+      ('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]))
+  def testBaseline(self, cls, num_microbatches, expected_answer):
     with self.cached_session() as sess:
       var0 = tf.Variable([1.0, 2.0])
       data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
 
-      opt = cls(learning_rate=2.0, nb_microbatches=nb_microbatches)
+      opt = cls(
+          l2_norm_clip=1.0e9,
+          noise_multiplier=0.0,
+          num_microbatches=num_microbatches,
+          learning_rate=2.0)
+
       self.evaluate(tf.global_variables_initializer())
 
       # Fetch params to validate initial values
       self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@@ -60,14 +68,20 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])
 
   @parameterized.named_parameters(
-      ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
-      ('DPAdam', dp_adam.DPAdamOptimizer))
+      ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
+      ('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
+      ('DPAdam', dp_optimizer.DPAdamOptimizer))
   def testClippingNorm(self, cls):
     with self.cached_session() as sess:
       var0 = tf.Variable([0.0, 0.0])
       data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
 
-      opt = cls(learning_rate=2.0, l2_norm_clip=1.0, nb_microbatches=1)
+      opt = cls(
+          l2_norm_clip=1.0,
+          noise_multiplier=0.0,
+          num_microbatches=1,
+          learning_rate=2.0)
+
       self.evaluate(tf.global_variables_initializer())
 
       # Fetch params to validate initial values
       self.assertAllClose([0.0, 0.0], self.evaluate(var0))
@@ -78,18 +92,22 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
 
   @parameterized.named_parameters(
-      ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
-      ('DPAdam', dp_adam.DPAdamOptimizer))
+      ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
+      ('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
+      ('DPAdam', dp_optimizer.DPAdamOptimizer))
   def testNoiseMultiplier(self, cls):
     with self.cached_session() as sess:
       var0 = tf.Variable([0.0])
       data0 = tf.Variable([[0.0]])
 
       opt = cls(
-          learning_rate=2.0, l2_norm_clip=4.0, noise_multiplier=2.0,
-          nb_microbatches=1)
+          l2_norm_clip=4.0,
+          noise_multiplier=2.0,
+          num_microbatches=1,
+          learning_rate=2.0)
+
       self.evaluate(tf.global_variables_initializer())
 
       # Fetch params to validate initial values
       self.assertAllClose([0.0], self.evaluate(var0))
@@ -103,6 +121,20 @@ class
DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       # Test standard deviation is close to l2_norm_clip * noise_multiplier.
       self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
 
+  @mock.patch.object(tf, 'logging')
+  def testComputeGradientsOverrideWarning(self, mock_logging):
+
+    class SimpleOptimizer(tf.train.Optimizer):
+
+      def compute_gradients(self):
+        return 0
+
+    dp_optimizer.make_optimizer_class(SimpleOptimizer)
+    mock_logging.warning.assert_called_once_with(
+        'WARNING: Calling make_optimizer_class() on class %s that overrides '
+        'method compute_gradients(). Check to ensure that '
+        'make_optimizer_class() does not interfere with overridden version.',
+        'SimpleOptimizer')
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tutorials/README.md b/tutorials/README.md
new file mode 100644
index 0000000..86e7132
--- /dev/null
+++ b/tutorials/README.md
@@ -0,0 +1,51 @@
+# Tutorials
+
+As demonstrated on MNIST in `mnist_dpsgd_tutorial.py`, the easiest way to use
+a differentially private optimizer is to modify an existing training loop
+by replacing the vanilla optimizer with its differentially private
+counterpart implemented in the library.
+
+## Parameters
+
+All of the optimizers share some privacy-specific parameters that need to
+be tuned in addition to any existing hyperparameters. There are currently
+three:
+
+* num_microbatches (int): The input data for each step (i.e., batch) of your
+  original training algorithm is split into this many microbatches. Generally,
+  increasing this will improve your utility but slow down your training in
+  terms of wall-clock time. The total number of examples consumed in one
+  global step remains the same. This number should evenly divide your input
+  batch size.
+* l2_norm_clip (float): The cumulative gradient across all network parameters
+  from each microbatch will be clipped so that its L2 norm is at most this
+  value. You should set this to something close to some percentile of what
+  you expect the gradient from each microbatch to be. In previous experiments,
+  we've found numbers from 0.5 to 1.0 to work reasonably well.
+* noise_multiplier (float): This governs the amount of noise added during
+  training. Generally, more noise results in better privacy and lower utility.
+  This generally has to be at least 0.3 to obtain rigorous privacy guarantees,
+  but smaller values may still be acceptable for practical purposes.
+
+## Measuring Privacy
+
+Differential privacy is measured by two values, epsilon and delta. Roughly
+speaking, they mean the following:
+
+* epsilon gives a ceiling on how much the probability of a change in model
+  behavior can increase by including a single extra training example. This is
+  the far more sensitive value, and we usually want it to be at most 10.0 or
+  so. However, note that this is only an upper bound, and a large value of
+  epsilon may still mean good practical privacy.
+* delta bounds the probability of an "unconditional" change in model behavior.
+  We can usually set this to a very small number (1e-7 or so) without
+  compromising utility. A rule of thumb is to set it to be less than the
+  inverse of the size of the training data.
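+
+Both values can be obtained from the RDP accountant bundled in
+`privacy/analysis/rdp_accountant.py`. A minimal sketch (the parameter values
+below are illustrative only):
+
+```
+from privacy.analysis.rdp_accountant import compute_rdp
+from privacy.analysis.rdp_accountant import get_privacy_spent
+
+# Illustrative setting: batch size 256 on 60000 examples for 15 epochs,
+# with noise multiplier 1.1.
+q = 256 / 60000.0            # sampling ratio per step
+steps = 15 * (60000 // 256)  # total number of training steps
+orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
+
+rdp = compute_rdp(q=q,
+                  stddev_to_sensitivity_ratio=1.1,
+                  steps=steps,
+                  orders=orders)
+eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=1e-5)
+print('epsilon = %.2f (optimal RDP order %.1f)' % (eps, opt_order))
+```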
+
+To find the epsilon for a fixed delta value for your model, follow the
+approach demonstrated in the `compute_epsilon` function of
+`mnist_dpsgd_tutorial.py`, where the arguments used to call the RDP accountant
+(i.e., the tool used to compute the privacy guarantee) are:
+
+* q: The sampling ratio, defined as (number of examples consumed in one
+  step) / (total training examples).
+* stddev_to_sensitivity_ratio: The noise_multiplier from your parameters above.
+* steps: The number of global steps taken.
diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py
new file mode 100644
index 0000000..fdacd63
--- /dev/null
+++ b/tutorials/mnist_dpsgd_tutorial.py
@@ -0,0 +1,183 @@
+# Copyright 2018, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Training a CNN on MNIST with a differentially private Adam optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from privacy.analysis.rdp_accountant import compute_rdp
+from privacy.analysis.rdp_accountant import get_privacy_spent
+from privacy.optimizers import dp_optimizer
+
+tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-Adam. If '
+                        'False, train with vanilla Adam.')
+tf.flags.DEFINE_float('learning_rate', 0.0015, 'Learning rate for training')
+tf.flags.DEFINE_float('noise_multiplier', 1.0,
+                      'Ratio of the standard deviation to the clipping norm')
+tf.flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
+tf.flags.DEFINE_integer('batch_size', 256, 'Batch size')
+tf.flags.DEFINE_integer('epochs', 15, 'Number of epochs')
+tf.flags.DEFINE_integer('microbatches', 256,
+                        'Number of microbatches (must evenly divide '
+                        'batch_size)')
+tf.flags.DEFINE_string('model_dir', None, 'Model directory')
+
+FLAGS = tf.flags.FLAGS
+
+
+def cnn_model_fn(features, labels, mode):
+  """Model function for a CNN."""
+
+  # Define CNN architecture using tf.keras.layers.
+  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
+  y = tf.keras.layers.Conv2D(16, 8,
+                             strides=2,
+                             padding='same',
+                             kernel_initializer='he_normal').apply(input_layer)
+  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
+  y = tf.keras.layers.Conv2D(32, 4,
+                             strides=2,
+                             padding='valid',
+                             kernel_initializer='he_normal').apply(y)
+  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
+  y = tf.keras.layers.Flatten().apply(y)
+  y = tf.keras.layers.Dense(32, kernel_initializer='he_normal').apply(y)
+  logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y)
+
+  # Calculate loss as a vector (to support microbatches in DP-SGD).
+  vector_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
+      labels=labels, logits=logits)
+  # Define mean of loss across minibatch (for reporting through tf.Estimator).
+  scalar_loss = tf.reduce_mean(vector_loss)
+
+  # Configure the training op (for TRAIN mode).
+  if mode == tf.estimator.ModeKeys.TRAIN:
+
+    if FLAGS.dpsgd:
+      # Use DP version of AdamOptimizer.
For illustration purposes, we do that
+      # here by calling make_optimizer_class() explicitly, though DP versions
+      # of standard optimizers are available in dp_optimizer.
+      dp_optimizer_class = dp_optimizer.make_optimizer_class(
+          tf.train.AdamOptimizer)
+      optimizer = dp_optimizer_class(
+          learning_rate=FLAGS.learning_rate,
+          noise_multiplier=FLAGS.noise_multiplier,
+          l2_norm_clip=FLAGS.l2_norm_clip,
+          num_microbatches=FLAGS.microbatches)
+    else:
+      optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
+    global_step = tf.train.get_global_step()
+    train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
+    return tf.estimator.EstimatorSpec(mode=mode,
+                                      loss=scalar_loss,
+                                      train_op=train_op)
+
+  # Add evaluation metrics (for EVAL mode).
+  elif mode == tf.estimator.ModeKeys.EVAL:
+    eval_metric_ops = {
+        'accuracy':
+            tf.metrics.accuracy(
+                labels=tf.argmax(labels, axis=1),
+                predictions=tf.argmax(input=logits, axis=1))
+    }
+    return tf.estimator.EstimatorSpec(mode=mode,
+                                      loss=scalar_loss,
+                                      eval_metric_ops=eval_metric_ops)
+
+
+def load_mnist():
+  """Loads MNIST and preprocesses to combine training and validation data."""
+  train, test = tf.keras.datasets.mnist.load_data()
+  train_data, train_labels = train
+  test_data, test_labels = test
+
+  train_data = np.array(train_data, dtype=np.float32) / 255
+  test_data = np.array(test_data, dtype=np.float32) / 255
+
+  train_labels = tf.keras.utils.to_categorical(train_labels)
+  test_labels = tf.keras.utils.to_categorical(test_labels)
+
+  assert train_data.min() == 0.
+  assert train_data.max() == 1.
+  assert test_data.min() == 0.
+  assert test_data.max() == 1.
+  assert train_labels.shape[1] == 10
+  assert test_labels.shape[1] == 10
+
+  return train_data, train_labels, test_data, test_labels
+
+
+def main(unused_argv):
+  tf.logging.set_verbosity(tf.logging.INFO)
+  if FLAGS.batch_size % FLAGS.microbatches != 0:
+    raise ValueError('Number of microbatches should evenly divide batch_size')
+
+  # Load training and test data.
+  train_data, train_labels, test_data, test_labels = load_mnist()
+
+  # Instantiate the tf.Estimator.
+  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
+                                            model_dir=FLAGS.model_dir)
+
+  # Create tf.Estimator input functions for the training and test data.
+  train_input_fn = tf.estimator.inputs.numpy_input_fn(
+      x={'x': train_data},
+      y=train_labels,
+      batch_size=FLAGS.batch_size,
+      num_epochs=FLAGS.epochs,
+      shuffle=True)
+  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
+      x={'x': test_data},
+      y=test_labels,
+      num_epochs=1,
+      shuffle=False)
+
+  # Define a function that computes the privacy budget expended so far.
+  def compute_epsilon(steps):
+    """Computes epsilon value for given hyperparameters."""
+    if FLAGS.noise_multiplier == 0.0:
+      return float('inf')
+    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
+    sampling_probability = FLAGS.batch_size / 60000
+    rdp = compute_rdp(q=sampling_probability,
+                      stddev_to_sensitivity_ratio=FLAGS.noise_multiplier,
+                      steps=steps,
+                      orders=orders)
+    # Delta is set to 1e-5 because MNIST has 60000 training points.
+    return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
+
+  # Training loop.
+  steps_per_epoch = 60000 // FLAGS.batch_size
+  for epoch in range(1, FLAGS.epochs + 1):
+    # Train the model for one epoch.
+    mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch)
+
+    # Evaluate the model and print results.
+    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
+    test_accuracy = eval_results['accuracy']
+    print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
+
+    # Compute the privacy budget expended so far.
+    if FLAGS.dpsgd:
+      eps = compute_epsilon(epoch * steps_per_epoch)
+      print('For delta=1e-5, the current epsilon is: %.2f' % eps)
+    else:
+      print('Trained with vanilla non-private Adam optimizer')
+
+
+if __name__ == '__main__':
+  tf.app.run()
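
To try the tutorial end to end once the library is on your `PYTHONPATH`, a
typical invocation looks like the following (all flags are optional; their
defaults are defined at the top of `mnist_dpsgd_tutorial.py`, and the values
below are illustrative only):

```
python tutorials/mnist_dpsgd_tutorial.py --dpsgd=True --noise_multiplier=1.1
```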