From b4188446e092a9246fc30b132e8257fb06a3092f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Dec 2018 09:13:33 -0800 Subject: [PATCH] Project import generated by Copybara. PiperOrigin-RevId: 226345615 --- privacy/optimizers/dp_optimizer.py | 22 +-- ...ian_average_query.py => gaussian_query.py} | 0 ...e_query_test.py => gaussian_query_test.py} | 61 ++++--- privacy/optimizers/nested_query.py | 123 +++++++++++++ privacy/optimizers/nested_query_test.py | 164 ++++++++++++++++++ tutorials/mnist_dpsgd_tutorial.py | 25 +-- 6 files changed, 351 insertions(+), 44 deletions(-) rename privacy/optimizers/{gaussian_average_query.py => gaussian_query.py} (100%) rename privacy/optimizers/{gaussian_average_query_test.py => gaussian_query_test.py} (62%) create mode 100644 privacy/optimizers/nested_query.py create mode 100644 privacy/optimizers/nested_query_test.py diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index f0f323b..138856d 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -19,7 +19,7 @@ from __future__ import print_function import tensorflow as tf -import privacy.optimizers.gaussian_average_query as ph +from privacy.optimizers import gaussian_query def make_optimizer_class(cls): @@ -40,9 +40,9 @@ def make_optimizer_class(cls): super(DPOptimizerClass, self).__init__(*args, **kwargs) stddev = l2_norm_clip * noise_multiplier self._num_microbatches = num_microbatches - self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev, - num_microbatches) - self._ph_global_state = self._privacy_helper.initial_global_state() + self._private_query = gaussian_query.GaussianAverageQuery( + l2_norm_clip, stddev, num_microbatches) + self._global_state = self._private_query.initial_global_state() def compute_gradients(self, loss, @@ -58,7 +58,7 @@ def make_optimizer_class(cls): # sampling from the dataset without replacement. microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1]) sample_params = ( - self._privacy_helper.derive_sample_params(self._ph_global_state)) + self._private_query.derive_sample_params(self._global_state)) def process_microbatch(i, sample_state): """Process one microbatch (record) with privacy helper.""" @@ -66,7 +66,7 @@ def make_optimizer_class(cls): tf.gather(microbatches_losses, [i]), var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, grad_loss)) grads_list = list(grads) - sample_state = self._privacy_helper.accumulate_record( + sample_state = self._private_query.accumulate_record( sample_params, sample_state, grads_list) return [tf.add(i, 1), sample_state] @@ -76,8 +76,8 @@ def make_optimizer_class(cls): var_list = ( tf.trainable_variables() + tf.get_collection( tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) - sample_state = self._privacy_helper.initial_sample_state( - self._ph_global_state, var_list) + sample_state = self._private_query.initial_sample_state( + self._global_state, var_list) # Use of while_loop here requires that sample_state be a nested structure # of tensors. In general, we would prefer to allow it to be an arbitrary @@ -85,9 +85,9 @@ def make_optimizer_class(cls): _, final_state = tf.while_loop( lambda i, _: tf.less(i, self._num_microbatches), process_microbatch, [i, sample_state]) - final_grads, self._ph_global_state = ( - self._privacy_helper.get_noised_average(final_state, - self._ph_global_state)) + final_grads, self._global_state = ( + self._private_query.get_noised_average(final_state, + self._global_state)) return zip(final_grads, var_list) diff --git a/privacy/optimizers/gaussian_average_query.py b/privacy/optimizers/gaussian_query.py similarity index 100% rename from privacy/optimizers/gaussian_average_query.py rename to privacy/optimizers/gaussian_query.py diff --git a/privacy/optimizers/gaussian_average_query_test.py b/privacy/optimizers/gaussian_query_test.py similarity index 62% rename from privacy/optimizers/gaussian_average_query_test.py rename to privacy/optimizers/gaussian_query_test.py index 28ce337..43c9085 100644 --- a/privacy/optimizers/gaussian_average_query_test.py +++ b/privacy/optimizers/gaussian_query_test.py @@ -18,32 +18,42 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np import tensorflow as tf -from privacy.optimizers import gaussian_average_query +from privacy.optimizers import gaussian_query -class GaussianAverageQueryTest(tf.test.TestCase): +def _run_query(query, records): + """Executes query on the given set of records as a single sample. - def _run_query(self, query, *records): - """Executes query on the given set of records and returns the result.""" - global_state = query.initial_global_state() - params = query.derive_sample_params(global_state) - sample_state = query.initial_sample_state(global_state, records[0]) - for record in records: - sample_state = query.accumulate_record(params, sample_state, record) - result, _ = query.get_query_result(sample_state, global_state) - return result + Args: + query: A PrivateQuery to run. + records: An iterable containing records to pass to the query. + + Returns: + The result of the query. + """ + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + sample_state = query.initial_sample_state(global_state, next(iter(records))) + for record in records: + sample_state = query.accumulate_record(params, sample_state, record) + result, _ = query.get_query_result(sample_state, global_state) + return result + + +class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase): def test_gaussian_sum_no_clip_no_noise(self): with self.cached_session() as sess: record1 = tf.constant([2.0, 0.0]) record2 = tf.constant([-1.0, 1.0]) - query = gaussian_average_query.GaussianSumQuery( + query = gaussian_query.GaussianSumQuery( l2_norm_clip=10.0, stddev=0.0) - query_result = self._run_query(query, record1, record2) + query_result = _run_query(query, [record1, record2]) result = sess.run(query_result) expected = [1.0, 1.0] self.assertAllClose(result, expected) @@ -53,9 +63,9 @@ class GaussianAverageQueryTest(tf.test.TestCase): record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. record2 = tf.constant([4.0, -3.0]) # Not clipped. - query = gaussian_average_query.GaussianSumQuery( + query = gaussian_query.GaussianSumQuery( l2_norm_clip=5.0, stddev=0.0) - query_result = self._run_query(query, record1, record2) + query_result = _run_query(query, [record1, record2]) result = sess.run(query_result) expected = [1.0, 1.0] self.assertAllClose(result, expected) @@ -65,9 +75,9 @@ class GaussianAverageQueryTest(tf.test.TestCase): record1, record2 = 2.71828, 3.14159 stddev = 1.0 - query = gaussian_average_query.GaussianSumQuery( + query = gaussian_query.GaussianSumQuery( l2_norm_clip=5.0, stddev=stddev) - query_result = self._run_query(query, record1, record2) + query_result = _run_query(query, [record1, record2]) noised_sums = [] for _ in xrange(1000): @@ -81,9 +91,9 @@ class GaussianAverageQueryTest(tf.test.TestCase): record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0]. record2 = tf.constant([-1.0, 2.0]) # Not clipped. - query = gaussian_average_query.GaussianAverageQuery( + query = gaussian_query.GaussianAverageQuery( l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0) - query_result = self._run_query(query, record1, record2) + query_result = _run_query(query, [record1, record2]) result = sess.run(query_result) expected_average = [1.0, 1.0] self.assertAllClose(result, expected_average) @@ -94,9 +104,9 @@ class GaussianAverageQueryTest(tf.test.TestCase): sum_stddev = 1.0 denominator = 2.0 - query = gaussian_average_query.GaussianAverageQuery( + query = gaussian_query.GaussianAverageQuery( l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator) - query_result = self._run_query(query, record1, record2) + query_result = _run_query(query, [record1, record2]) noised_averages = [] for _ in xrange(1000): @@ -106,6 +116,15 @@ class GaussianAverageQueryTest(tf.test.TestCase): avg_stddev = sum_stddev / denominator self.assertNear(result_stddev, avg_stddev, 0.1) + @parameterized.named_parameters( + ('type_mismatch', [1.0], (1.0,), TypeError), + ('too_few_on_left', [1.0], [1.0, 1.0], ValueError), + ('too_few_on_right', [1.0, 1.0], [1.0], ValueError)) + def test_incompatible_records(self, record1, record2, error_type): + query = gaussian_query.GaussianSumQuery(1.0, 0.0) + with self.assertRaises(error_type): + _run_query(query, [record1, record2]) + if __name__ == '__main__': tf.test.main() diff --git a/privacy/optimizers/nested_query.py b/privacy/optimizers/nested_query.py new file mode 100644 index 0000000..3cd073e --- /dev/null +++ b/privacy/optimizers/nested_query.py @@ -0,0 +1,123 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements PrivateQuery interface for queries over nested structures. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import tensorflow as tf + +from privacy.optimizers import private_queries + +nest = tf.contrib.framework.nest + + +class NestedQuery(private_queries.PrivateQuery): + """Implements PrivateQuery interface for structured queries. + + NestedQuery evaluates arbitrary nested structures of queries. Records must be + nested structures of tensors that are compatible (in type and arity) with the + query structure, but are allowed to have deeper structure within each leaf of + the query structure. For example, the nested query [q1, q2] is compatible with + the record [t1, t2] or [t1, (t2, t3)], but not with (t1, t2), [t1] or + [t1, t2, t3]. The entire substructure of each record corresponding to a leaf + node of the query structure is routed to the corresponding query. If the same + tensor should be consumed by multiple sub-queries, it can be replicated in the + record, for example [t1, t1]. + + NestedQuery is intended to allow privacy mechanisms for groups as described in + [McMahan & Andrew, 2018: "A General Approach to Adding Differential Privacy to + Iterative Training Procedures" (https://arxiv.org/abs/1812.06210)]. + """ + + def __init__(self, queries): + """Initializes the NestedQuery. + + Args: + queries: A nested structure of queries. + """ + self._queries = queries + + def _map_to_queries(self, fn, *inputs): + def caller(query, *args): + return getattr(query, fn)(*args) + return nest.map_structure_up_to( + self._queries, caller, self._queries, *inputs) + + def initial_global_state(self): + """Returns the initial global state for the NestedQuery.""" + return self._map_to_queries('initial_global_state') + + def derive_sample_params(self, global_state): + """Given the global state, derives parameters to use for the next sample. + + Args: + global_state: The current global state. + + Returns: + Parameters to use to process records in the next sample. + """ + return self._map_to_queries('derive_sample_params', global_state) + + def initial_sample_state(self, global_state, tensors): + """Returns an initial state to use for the next sample. + + Args: + global_state: The current global state. + tensors: A structure of tensors used as a template to create the initial + sample state. + + Returns: An initial sample state. + """ + return self._map_to_queries('initial_sample_state', global_state, tensors) + + def accumulate_record(self, params, sample_state, record): + """Accumulates a single record into the sample state. + + Args: + params: The parameters for the sample. + sample_state: The current sample state. + record: The record to accumulate. + + Returns: + The updated sample state. + """ + return self._map_to_queries( + 'accumulate_record', params, sample_state, record) + + def get_query_result(self, sample_state, global_state): + """Gets query result after all records of sample have been accumulated. + + Args: + sample_state: The sample state after all records have been accumulated. + global_state: The global state. + + Returns: + A tuple (result, new_global_state) where "result" is a structure matching + the query structure containing the results of the subqueries and + "new_global_state" is a structure containing the updated global states + for the subqueries. + """ + estimates_and_new_global_states = self._map_to_queries( + 'get_query_result', sample_state, global_state) + + flat_estimates, flat_new_global_states = zip( + *nest.flatten_up_to(self._queries, estimates_and_new_global_states)) + return ( + nest.pack_sequence_as(self._queries, flat_estimates), + nest.pack_sequence_as(self._queries, flat_new_global_states)) diff --git a/privacy/optimizers/nested_query_test.py b/privacy/optimizers/nested_query_test.py new file mode 100644 index 0000000..134dbb2 --- /dev/null +++ b/privacy/optimizers/nested_query_test.py @@ -0,0 +1,164 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for NestedQuery.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from privacy.optimizers import gaussian_query +from privacy.optimizers import nested_query + +nest = tf.contrib.framework.nest + +_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0) + + +def _run_query(query, records): + """Executes query on the given set of records as a single sample. + + Args: + query: A PrivateQuery to run. + records: An iterable containing records to pass to the query. + + Returns: + The result of the query. + """ + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + sample_state = query.initial_sample_state(global_state, next(iter(records))) + for record in records: + sample_state = query.accumulate_record(params, sample_state, record) + result, _ = query.get_query_result(sample_state, global_state) + return result + + +class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_nested_gaussian_sum_no_clip_no_noise(self): + with self.cached_session() as sess: + query1 = gaussian_query.GaussianSumQuery( + l2_norm_clip=10.0, stddev=0.0) + query2 = gaussian_query.GaussianSumQuery( + l2_norm_clip=10.0, stddev=0.0) + + query = nested_query.NestedQuery([query1, query2]) + + record1 = [1.0, [2.0, 3.0]] + record2 = [4.0, [3.0, 2.0]] + + query_result = _run_query(query, [record1, record2]) + result = sess.run(query_result) + expected = [5.0, [5.0, 5.0]] + self.assertAllClose(result, expected) + + def test_nested_gaussian_average_no_clip_no_noise(self): + with self.cached_session() as sess: + query1 = gaussian_query.GaussianAverageQuery( + l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0) + query2 = gaussian_query.GaussianAverageQuery( + l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0) + + query = nested_query.NestedQuery([query1, query2]) + + record1 = [1.0, [2.0, 3.0]] + record2 = [4.0, [3.0, 2.0]] + + query_result = _run_query(query, [record1, record2]) + result = sess.run(query_result) + expected = [1.0, [1.0, 1.0]] + self.assertAllClose(result, expected) + + def test_nested_gaussian_average_with_clip_no_noise(self): + with self.cached_session() as sess: + query1 = gaussian_query.GaussianAverageQuery( + l2_norm_clip=4.0, sum_stddev=0.0, denominator=5.0) + query2 = gaussian_query.GaussianAverageQuery( + l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0) + + query = nested_query.NestedQuery([query1, query2]) + + record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]] + record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]] + + query_result = _run_query(query, [record1, record2]) + result = sess.run(query_result) + expected = [1.0, [1.0, 1.0]] + self.assertAllClose(result, expected) + + def test_complex_nested_query(self): + with self.cached_session() as sess: + query_ab = gaussian_query.GaussianSumQuery( + l2_norm_clip=1.0, stddev=0.0) + query_c = gaussian_query.GaussianAverageQuery( + l2_norm_clip=10.0, sum_stddev=0.0, denominator=2.0) + query_d = gaussian_query.GaussianSumQuery( + l2_norm_clip=10.0, stddev=0.0) + + query = nested_query.NestedQuery( + [query_ab, {'c': query_c, 'd': [query_d]}]) + + record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}] + record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}] + + query_result = _run_query(query, [record1, record2]) + result = sess.run(query_result) + expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}] + self.assertAllClose(result, expected) + + def test_nested_query_with_noise(self): + with self.cached_session() as sess: + sum_stddev = 2.71828 + denominator = 3.14159 + + query1 = gaussian_query.GaussianSumQuery( + l2_norm_clip=1.5, stddev=sum_stddev) + query2 = gaussian_query.GaussianAverageQuery( + l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator) + query = nested_query.NestedQuery((query1, query2)) + + record1 = (3.0, [2.0, 1.5]) + record2 = (0.0, [-1.0, -3.5]) + + query_result = _run_query(query, [record1, record2]) + + noised_averages = [] + for _ in xrange(1000): + noised_averages.append(nest.flatten(sess.run(query_result))) + + result_stddev = np.std(noised_averages, 0) + avg_stddev = sum_stddev / denominator + expected_stddev = [sum_stddev, avg_stddev, avg_stddev] + self.assertArrayNear(result_stddev, expected_stddev, 0.1) + + @parameterized.named_parameters( + ('type_mismatch', [_basic_query], (1.0,), TypeError), + ('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError), + ('too_many_records', [_basic_query, _basic_query], + [1.0, 2.0, 3.0], ValueError), + ('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError)) + def test_record_incompatible_with_query( + self, queries, record, error_type): + with self.assertRaises(error_type): + _run_query(nested_query.NestedQuery(queries), [record]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py index fdacd63..d87defb 100644 --- a/tutorials/mnist_dpsgd_tutorial.py +++ b/tutorials/mnist_dpsgd_tutorial.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Training a CNN on MNIST with differentially private Adam optimizer.""" +"""Training a CNN on MNIST with differentially private SGD optimizer.""" from __future__ import absolute_import from __future__ import division @@ -25,14 +25,14 @@ from privacy.analysis.rdp_accountant import compute_rdp from privacy.analysis.rdp_accountant import get_privacy_spent from privacy.optimizers import dp_optimizer -tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-Adam. If False,' - 'train with vanilla Adam.') -tf.flags.DEFINE_float('learning_rate', 0.0015, 'Learning rate for training') -tf.flags.DEFINE_float('noise_multiplier', 1.0, +tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False,' + 'train with vanilla SGD.') +tf.flags.DEFINE_float('learning_rate', 0.08, 'Learning rate for training') +tf.flags.DEFINE_float('noise_multiplier', 1.12, 'Ratio of the standard deviation to the clipping norm') tf.flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') tf.flags.DEFINE_integer('batch_size', 256, 'Batch size') -tf.flags.DEFINE_integer('epochs', 15, 'Number of epochs') +tf.flags.DEFINE_integer('epochs', 60, 'Number of epochs') tf.flags.DEFINE_integer('microbatches', 256, 'Number of microbatches (must evenly divide batch_size') tf.flags.DEFINE_string('model_dir', None, 'Model directory') @@ -69,18 +69,19 @@ def cnn_model_fn(features, labels, mode): if mode == tf.estimator.ModeKeys.TRAIN: if FLAGS.dpsgd: - # Use DP version of AdamOptimizer. For illustration purposes, we do that - # here by calling make_optimizer_class() explicitly, though DP versions - # of standard optimizers are available in dp_optimizer. + # Use DP version of GradientDescentOptimizer. For illustration purposes, + # we do that here by calling make_optimizer_class() explicitly, though DP + # versions of standard optimizers are available in dp_optimizer. dp_optimizer_class = dp_optimizer.make_optimizer_class( - tf.train.AdamOptimizer) + tf.train.GradientDescentOptimizer) optimizer = dp_optimizer_class( learning_rate=FLAGS.learning_rate, noise_multiplier=FLAGS.noise_multiplier, l2_norm_clip=FLAGS.l2_norm_clip, num_microbatches=FLAGS.microbatches) else: - optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) + optimizer = tf.train.GradientDescentOptimizer( + learning_rate=FLAGS.learning_rate) global_step = tf.train.get_global_step() train_op = optimizer.minimize(loss=vector_loss, global_step=global_step) return tf.estimator.EstimatorSpec(mode=mode, @@ -177,7 +178,7 @@ def main(unused_argv): eps = compute_epsilon(epoch * steps_per_epoch) print('For delta=1e-5, the current epsilon is: %.2f' % eps) else: - print('Trained with vanilla non-private Adam optimizer') + print('Trained with vanilla non-private SGD optimizer') if __name__ == '__main__': tf.app.run()