From afb8189dba2fb310ca582633dcacc74e221855ff Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Tue, 4 Dec 2018 15:50:21 -0800 Subject: [PATCH] PiperOrigin-RevId: 224061027 --- README.md | 7 +- privacy/optimizers/BUILD | 55 +++++++++ privacy/optimizers/dp_adam.py | 122 +++++++++++++++++++ privacy/optimizers/dp_gradient_descent.py | 106 ++++++++++++++++ privacy/optimizers/dp_optimizer_test.py | 108 ++++++++++++++++ privacy/optimizers/gaussian_average_query.py | 108 ++++++++++++++++ privacy/optimizers/private_queries.py | 90 ++++++++++++++ privacy/test.py | 31 ----- 8 files changed, 590 insertions(+), 37 deletions(-) create mode 100644 privacy/optimizers/BUILD create mode 100644 privacy/optimizers/dp_adam.py create mode 100644 privacy/optimizers/dp_gradient_descent.py create mode 100644 privacy/optimizers/dp_optimizer_test.py create mode 100644 privacy/optimizers/gaussian_average_query.py create mode 100644 privacy/optimizers/private_queries.py delete mode 100644 privacy/test.py diff --git a/README.md b/README.md index 1d22c0b..1e799d7 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,5 @@ This repository will contain implementations of TensorFlow optimizers that support training machine learning models with (differential) privacy, as well as tutorials and analysis tools for computing the privacy guarantees provided. -The content of this repository will superseed the following existing repository: +The content of this repository will supersede the following existing repository: https://github.com/tensorflow/models/tree/master/research/differential_privacy - -# Contact - -* Steve Chien (schien@google.com) -* Nicolas Papernot (@npapernot) diff --git a/privacy/optimizers/BUILD b/privacy/optimizers/BUILD new file mode 100644 index 0000000..8820512 --- /dev/null +++ b/privacy/optimizers/BUILD @@ -0,0 +1,55 @@ +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "gaussian_average_query", + srcs = ["gaussian_average_query.py"], + deps = [ + ":private_queries", + "//third_party/py/tensorflow", + ], +) + +py_library( + name = "dp_optimizers", + deps = [ + ":dp_adam", + ":dp_gradient_descent", + ], +) + +py_library( + name = "dp_adam", + srcs = [ + "dp_adam.py", + ], + deps = [ + ":gaussian_average_query", + "//third_party/py/tensorflow", + ], +) + +py_library( + name = "dp_gradient_descent", + srcs = [ + "dp_gradient_descent.py", + ], + deps = [ + ":gaussian_average_query", + "//third_party/py/tensorflow", + ], +) + +py_test( + name = "dp_optimizer_test", + srcs = ["dp_optimizer_test.py"], + deps = [ + ":dp_optimizers", + "//third_party/py/absl/testing:parameterized", + "//third_party/py/tensorflow", + ], +) + +py_library( + name = "private_queries", + srcs = ["private_queries.py"], +) diff --git a/privacy/optimizers/dp_adam.py b/privacy/optimizers/dp_adam.py new file mode 100644 index 0000000..135aa97 --- /dev/null +++ b/privacy/optimizers/dp_adam.py @@ -0,0 +1,122 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPAdamOptimizer for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +import tensorflow_privacy.privacy.optimizers.gaussian_average_query as ph + + +class DPAdamOptimizer(tf.train.AdamOptimizer): + """Optimizer that implements the DP Adam algorithm. + + """ + + def __init__(self, + learning_rate, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + use_locking=False, + l2_norm_clip=1e9, + noise_multiplier=0.0, + nb_microbatches=1, + name='DPAdam'): + """Construct a new DP Adam optimizer. + + Args: + learning_rate: A Tensor or a floating point value. The learning rate to + use. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + l2_norm_clip: Clipping parameter for DP-SGD. + noise_multiplier: Noise multiplier for DP-SGD. + nb_microbatches: Number of microbatches in which to split the input. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "DPAdam". @compatibility(eager) When eager + execution is enabled, `learning_rate` can be a callable that takes no + arguments and returns the actual value to use. This can be useful for + changing these values across different invocations of optimizer + functions. @end_compatibility + """ + super(DPAdamOptimizer, self).__init__( + learning_rate, + beta1, + beta2, + epsilon, + use_locking, + name) + stddev = l2_norm_clip * noise_multiplier + self._nb_microbatches = nb_microbatches + self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev, + nb_microbatches) + self._ph_global_state = self._privacy_helper.initial_global_state() + + def compute_gradients(self, + loss, + var_list, + gate_gradients=tf.train.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + + # Note: it would be closer to the correct i.i.d. sampling of records if + # we sampled each microbatch from the appropriate binomial distribution, + # although that still wouldn't be quite correct because it would be sampling + # from the dataset without replacement. + microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1]) + sample_params = ( + self._privacy_helper.derive_sample_params(self._ph_global_state)) + + def process_microbatch(i, sample_state): + """Process one microbatch (record) with privacy helper.""" + grads, _ = zip(*super(DPAdamOptimizer, self).compute_gradients( + tf.gather(microbatches_losses, [i]), var_list, gate_gradients, + aggregation_method, colocate_gradients_with_ops, grad_loss)) + sample_state = self._privacy_helper.accumulate_record( + sample_params, sample_state, grads) + return [tf.add(i, 1), sample_state] + + i = tf.constant(0) + + if var_list is None: + var_list = ( + tf.trainable_variables() + + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + sample_state = self._privacy_helper.initial_sample_state( + self._ph_global_state, var_list) + + # Use of while_loop here requires that sample_state be a nested structure of + # tensors. In general, we would prefer to allow it to be an arbitrary + # opaque type. + _, final_state = tf.while_loop( + lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch, + [i, sample_state]) + final_grads, self._ph_global_state = ( + self._privacy_helper.get_noised_average(final_state, + self._ph_global_state)) + + return zip(final_grads, var_list) + diff --git a/privacy/optimizers/dp_gradient_descent.py b/privacy/optimizers/dp_gradient_descent.py new file mode 100644 index 0000000..97aa4b1 --- /dev/null +++ b/privacy/optimizers/dp_gradient_descent.py @@ -0,0 +1,106 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPGradientDescentOptimizer for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +import tensorflow_privacy.privacy.optimizers.gaussian_average_query as ph + + +class DPGradientDescentOptimizer(tf.train.GradientDescentOptimizer): + """Optimizer that implements the DP gradient descent algorithm. + + """ + + def __init__(self, + learning_rate, + use_locking=False, + l2_norm_clip=1e9, + noise_multiplier=0.0, + nb_microbatches=1, + name='DPGradientDescent'): + """Construct a new DP gradient descent optimizer. + + Args: + learning_rate: A Tensor or a floating point value. The learning rate to + use. + use_locking: If True use locks for update operations. + l2_norm_clip: Clipping parameter for DP-SGD. + noise_multiplier: Noise multiplier for DP-SGD. + nb_microbatches: Number of microbatches in which to split the input. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "DPGradientDescent". @compatibility(eager) When + eager execution is enabled, `learning_rate` can be a callable that takes + no arguments and returns the actual value to use. This can be useful for + changing these values across different invocations of optimizer + functions. @end_compatibility + """ + super(DPGradientDescentOptimizer, self).__init__(learning_rate, use_locking, + name) + stddev = l2_norm_clip * noise_multiplier + self._nb_microbatches = nb_microbatches + self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev, + nb_microbatches) + self._ph_global_state = self._privacy_helper.initial_global_state() + + def compute_gradients(self, + loss, + var_list, + gate_gradients=tf.train.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + + # Note: it would be closer to the correct i.i.d. sampling of records if + # we sampled each microbatch from the appropriate binomial distribution, + # although that still wouldn't be quite correct because it would be sampling + # from the dataset without replacement. + microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1]) + sample_params = ( + self._privacy_helper.derive_sample_params(self._ph_global_state)) + + def process_microbatch(i, sample_state): + """Process one microbatch (record) with privacy helper.""" + grads, _ = zip(*super(DPGradientDescentOptimizer, self).compute_gradients( + tf.gather(microbatches_losses, [i]), var_list, gate_gradients, + aggregation_method, colocate_gradients_with_ops, grad_loss)) + sample_state = self._privacy_helper.accumulate_record( + sample_params, sample_state, grads) + return [tf.add(i, 1), sample_state] + + i = tf.constant(0) + + if var_list is None: + var_list = ( + tf.trainable_variables() + + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + sample_state = self._privacy_helper.initial_sample_state( + self._ph_global_state, var_list) + + # Use of while_loop here requires that sample_state be a nested structure of + # tensors. In general, we would prefer to allow it to be an arbitrary + # opaque type. + _, final_state = tf.while_loop( + lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch, + [i, sample_state]) + final_grads, self._ph_global_state = ( + self._privacy_helper.get_noised_average(final_state, + self._ph_global_state)) + + return zip(final_grads, var_list) diff --git a/privacy/optimizers/dp_optimizer_test.py b/privacy/optimizers/dp_optimizer_test.py new file mode 100644 index 0000000..d351bf0 --- /dev/null +++ b/privacy/optimizers/dp_optimizer_test.py @@ -0,0 +1,108 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for differentially private optimizers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.optimizers import dp_adam +from tensorflow_privacy.privacy.optimizers import dp_gradient_descent + + +def loss(val0, val1): + """Loss function that is minimized at the mean of the input points.""" + return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1) + + +class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + + # Parameters for testing: optimizer, nb_microbatches, expected answer. + @parameterized.named_parameters( + ('DPGradientDescent 1', dp_gradient_descent.DPGradientDescentOptimizer, 1, + [-10.0, -10.0]), + ('DPGradientDescent 2', dp_gradient_descent.DPGradientDescentOptimizer, 2, + [-5.0, -5.0]), + ('DPGradientDescent 4', dp_gradient_descent.DPGradientDescentOptimizer, 4, + [-2.5, -2.5]), ('DPAdam 1', dp_adam.DPAdamOptimizer, 1, [-10.0, -10.0]), + ('DPAdam 2', dp_adam.DPAdamOptimizer, 2, [-5.0, -5.0]), + ('DPAdam 4', dp_adam.DPAdamOptimizer, 4, [-2.5, -2.5])) + def testBaseline(self, cls, nb_microbatches, expected_answer): + with self.cached_session() as sess: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + + opt = cls(learning_rate=2.0, nb_microbatches=nb_microbatches) + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + # Expected gradient is sum of differences divided by number of + # microbatches. + gradient_op = opt.compute_gradients(loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer), + ('DPAdam', dp_adam.DPAdamOptimizer)) + def testClippingNorm(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0, 0.0]) + data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) + + opt = cls(learning_rate=2.0, l2_norm_clip=1.0, nb_microbatches=1) + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + + # Expected gradient is sum of differences. + gradient_op = opt.compute_gradients(loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer), + ('DPAdam', dp_adam.DPAdamOptimizer)) + def testNoiseMultiplier(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + opt = cls( + learning_rate=2.0, + l2_norm_clip=4.0, + noise_multiplier=2.0, + nb_microbatches=1) + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + gradient_op = opt.compute_gradients(loss(data0, var0), [var0]) + grads = [] + for _ in xrange(1000): + grads_and_vars = sess.run(gradient_op) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear(np.std(grads), 2.0 * 4.0, 0.5) + + +if __name__ == '__main__': + tf.test.main() diff --git a/privacy/optimizers/gaussian_average_query.py b/privacy/optimizers/gaussian_average_query.py new file mode 100644 index 0000000..9123ffa --- /dev/null +++ b/privacy/optimizers/gaussian_average_query.py @@ -0,0 +1,108 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements PrivateQuery interface for Gaussian average queries. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import tensorflow as tf + +from tensorflow_privacy.privacy.optimizers import private_queries + + +class GaussianAverageQuery(private_queries.PrivateAverageQuery): + """Implements PrivateQuery interface for Gaussian average queries. + + Accumulates clipped vectors, then adds Gaussian noise to the average. + """ + + # pylint: disable=invalid-name + _GlobalState = collections.namedtuple( + '_GlobalState', ['l2_norm_clip', 'stddev', 'denominator']) + + def __init__(self, l2_norm_clip, stddev, denominator): + """Initializes the GaussianAverageQuery.""" + self._l2_norm_clip = l2_norm_clip + self._stddev = stddev + self._denominator = denominator + + def initial_global_state(self): + """Returns the initial global state for the PrivacyHelper.""" + return self._GlobalState( + float(self._l2_norm_clip), float(self._stddev), + float(self._denominator)) + + def derive_sample_params(self, global_state): + """Given the global state, derives parameters to use for the next sample. + + Args: + global_state: The current global state. + + Returns: + Parameters to use to process records in the next sample. + """ + return global_state.l2_norm_clip + + def initial_sample_state(self, global_state, tensors): + """Returns an initial state to use for the next sample. + + Args: + global_state: The current global state. + tensors: A structure of tensors used as a template to create the initial + sample state. + + Returns: An initial sample state. + """ + del global_state # unused. + return tf.contrib.framework.nest.map_structure(tf.zeros_like, tensors) + + def accumulate_record(self, params, sample_state, record): + """Accumulates a single record into the sample state. + + Args: + params: The parameters for the sample. + sample_state: The current sample state. + record: The record to accumulate. + + Returns: + The updated sample state. + """ + l2_norm_clip = params + clipped, _ = tf.clip_by_global_norm(record, l2_norm_clip) + return tf.contrib.framework.nest.map_structure(tf.add, sample_state, + clipped) + + def get_noised_average(self, sample_state, global_state): + """Gets noised average after all records of sample have been accumulated. + + Args: + sample_state: The sample state after all records have been accumulated. + global_state: The global state. + + Returns: + A tuple (estimate, new_global_state) where "estimate" is the estimated + average of the records and "new_global_state" is the updated global state. + """ + def noised_average(v): + return tf.truediv( + v + tf.random_normal(tf.shape(v), stddev=self._stddev), + global_state.denominator) + + return (tf.contrib.framework.nest.map_structure(noised_average, + sample_state), global_state) diff --git a/privacy/optimizers/private_queries.py b/privacy/optimizers/private_queries.py new file mode 100644 index 0000000..86a1967 --- /dev/null +++ b/privacy/optimizers/private_queries.py @@ -0,0 +1,90 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An interface for differentially private query mechanisms. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + + +class PrivateQuery(object): + """Interface for differentially private query mechanisms.""" + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def initial_global_state(self): + """Returns the initial global state for the PrivateQuery.""" + pass + + @abc.abstractmethod + def derive_sample_params(self, global_state): + """Given the global state, derives parameters to use for the next sample. + + Args: + global_state: The current global state. + + Returns: + Parameters to use to process records in the next sample. + """ + pass + + @abc.abstractmethod + def initial_sample_state(self, global_state, tensors): + """Returns an initial state to use for the next sample. + + Args: + global_state: The current global state. + tensors: A structure of tensors used as a template to create the initial + sample state. + + Returns: An initial sample state. + """ + pass + + @abc.abstractmethod + def accumulate_record(self, params, sample_state, record): + """Accumulates a single record into the sample state. + + Args: + params: The parameters for the sample. + sample_state: The current sample state. + record: The record to accumulate. + + Returns: + The updated sample state. + """ + pass + + +class PrivateAverageQuery(PrivateQuery): + """Interface for differentially private mechanisms to compute an average.""" + + @abc.abstractmethod + def get_noised_average(self, sample_state, global_state): + """Gets average estimate after all records of sample have been accumulated. + + Args: + sample_state: The sample state after all records have been accumulated. + global_state: The global state. + + Returns: + A tuple (estimate, new_global_state) where "estimate" is the estimated + average of the records and "new_global_state" is the updated global state. + """ + pass diff --git a/privacy/test.py b/privacy/test.py deleted file mode 100644 index 8773bc0..0000000 --- a/privacy/test.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import app -from absl import flags - -FLAGS = flags.FLAGS - - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - -if __name__ == '__main__': - app.run(main)