From 45bcb3a0e4d5ed7759af348725e7d031e0e22b96 Mon Sep 17 00:00:00 2001
From: Ilya Mironov
Date: Mon, 24 Jun 2019 12:49:48 -0700
Subject: [PATCH 01/10] Adding privacy analysis to the Logistic Regression for MNIST tutorial.

PiperOrigin-RevId: 254815428
---
 tutorials/README.md                           |   4 +
 ...gression_mnist.py => mnist_lr_tutorial.py} | 145 +++++++++++-------
 2 files changed, 92 insertions(+), 57 deletions(-)
 rename tutorials/{logistic_regression_mnist.py => mnist_lr_tutorial.py} (59%)

diff --git a/tutorials/README.md b/tutorials/README.md
index 94b5cef..d3f60d3 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -20,6 +20,10 @@ Here is a list of all the tutorials included:
 * `mnist_dpsgd_tutorial_keras.py`: learn a convolutional neural network on
   MNIST with differential privacy using tf.Keras.
 
+* `mnist_lr_tutorial.py`: learn a differentially private logistic regression
+  model on MNIST. The model illustrates application of the
+  "amplification-by-iteration" analysis (https://arxiv.org/abs/1808.06651).
+
 The rest of this README describes the different parameters used to configure
 DP-SGD as well as expected outputs for the `mnist_dpsgd_tutorial.py` tutorial.
 
diff --git a/tutorials/logistic_regression_mnist.py b/tutorials/mnist_lr_tutorial.py
similarity index 59%
rename from tutorials/logistic_regression_mnist.py
rename to tutorials/mnist_lr_tutorial.py
index 694ee7d..62f446d 100644
--- a/tutorials/logistic_regression_mnist.py
+++ b/tutorials/mnist_lr_tutorial.py
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """DP Logistic Regression on MNIST.
 
 DP Logistic Regression on MNIST with support for privacy-by-iteration analysis.
-Feldman, Vitaly, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta.
+Vitaly Feldman, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta.
 "Privacy amplification by iteration."
 In 2018 IEEE 59th Annual Symposium on Foundations of Computer Science (FOCS),
 pp. 521-532. IEEE, 2018.
@@ -36,6 +35,8 @@ from distutils.version import LooseVersion
 import numpy as np
 import tensorflow as tf
 
+from privacy.analysis.rdp_accountant import compute_rdp
+from privacy.analysis.rdp_accountant import get_privacy_spent
 from privacy.optimizers import dp_optimizer
 
 if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
@@ -45,32 +46,30 @@ else:
 
 FLAGS = flags.FLAGS
 
-flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
-                     'train with vanilla SGD.')
+flags.DEFINE_boolean(
+    'dpsgd', True, 'If True, train with DP-SGD. If False, '
+    'train with vanilla SGD.')
 flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training')
-flags.DEFINE_float('noise_multiplier', 0.02,
+flags.DEFINE_float('noise_multiplier', 0.05,
                    'Ratio of the standard deviation to the clipping norm')
-flags.DEFINE_integer('batch_size', 1, 'Batch size')
+flags.DEFINE_integer('batch_size', 5, 'Batch size')
 flags.DEFINE_integer('epochs', 5, 'Number of epochs')
-flags.DEFINE_integer('microbatches', 1, 'Number of microbatches '
-                     '(must evenly divide batch_size)')
 flags.DEFINE_float('regularizer', 0, 'L2 regularizer coefficient')
 flags.DEFINE_string('model_dir', None, 'Model directory')
-flags.DEFINE_float('data_l2_norm', 8,
-                   'Bound on the L2 norm of normalized data.')
+flags.DEFINE_float('data_l2_norm', 8, 'Bound on the L2 norm of normalized data')
 
 
 def lr_model_fn(features, labels, mode, nclasses, dim):
   """Model function for logistic regression."""
   input_layer = tf.reshape(features['x'], tuple([-1]) + dim)
 
-  logits = tf.layers.dense(inputs=input_layer,
-                           units=nclasses,
-                           kernel_regularizer=tf.contrib.layers.l2_regularizer(
-                               scale=FLAGS.regularizer),
-                           bias_regularizer=tf.contrib.layers.l2_regularizer(
-                               scale=FLAGS.regularizer)
-                           )
+  logits = tf.layers.dense(
+      inputs=input_layer,
+      units=nclasses,
+      kernel_regularizer=tf.contrib.layers.l2_regularizer(
+          scale=FLAGS.regularizer),
+      bias_regularizer=tf.contrib.layers.l2_regularizer(
+          scale=FLAGS.regularizer))
 
   # Calculate loss as a vector (to support microbatches in DP-SGD).
   vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -80,18 +79,15 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
 
   # Configure the training op (for TRAIN mode).
   if mode == tf.estimator.ModeKeys.TRAIN:
-
     if FLAGS.dpsgd:
-      # Use DP version of GradientDescentOptimizer. Other optimizers are
-      # available in dp_optimizer. Most optimizers inheriting from
-      # tf.train.Optimizer should be wrappable in differentially private
-      # counterparts by calling dp_optimizer.optimizer_from_args().
       # The loss function is L-Lipschitz with L = sqrt(2*(||x||^2 + 1)) where
       # ||x|| is the norm of the data.
+      # We don't use microbatches (thus speeding up computation), since no
+      # clipping is necessary due to data normalization.
       optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
-          l2_norm_clip=math.sqrt(2*(FLAGS.data_l2_norm**2 + 1)),
+          l2_norm_clip=math.sqrt(2 * (FLAGS.data_l2_norm**2 + 1)),
           noise_multiplier=FLAGS.noise_multiplier,
-          num_microbatches=FLAGS.microbatches,
+          num_microbatches=1,
           learning_rate=FLAGS.learning_rate)
       opt_loss = vector_loss
     else:
@@ -103,21 +99,18 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
   # the vector_loss because tf.estimator requires a scalar loss. This is only
   # used for evaluation and debugging by tf.estimator. The actual loss being
   # minimized is opt_loss defined above and passed to optimizer.minimize().
-    return tf.estimator.EstimatorSpec(mode=mode,
-                                      loss=scalar_loss,
-                                      train_op=train_op)
+    return tf.estimator.EstimatorSpec(
+        mode=mode, loss=scalar_loss, train_op=train_op)
 
   # Add evaluation metrics (for EVAL mode).
   elif mode == tf.estimator.ModeKeys.EVAL:
     eval_metric_ops = {
         'accuracy':
             tf.metrics.accuracy(
-                labels=labels,
-                predictions=tf.argmax(input=logits, axis=1))
+                labels=labels, predictions=tf.argmax(input=logits, axis=1))
     }
-    return tf.estimator.EstimatorSpec(mode=mode,
-                                      loss=scalar_loss,
-                                      eval_metric_ops=eval_metric_ops)
+    return tf.estimator.EstimatorSpec(
+        mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)
 
 
 def normalize_data(data, data_l2_norm):
@@ -146,7 +139,7 @@ def load_mnist(data_l2_norm=float('inf')):
   train_data = train_data.reshape(train_data.shape[0], -1)
   test_data = test_data.reshape(test_data.shape[0], -1)
 
-  idx = np.random.permutation(len(train_data)) # shuffle data once
+  idx = np.random.permutation(len(train_data))  # shuffle data once
   train_data = train_data[idx]
   train_labels = train_labels[idx]
 
@@ -159,14 +152,50 @@ def load_mnist(data_l2_norm=float('inf')):
   return train_data, train_labels, test_data, test_labels
 
 
+def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
+  """Tabulating position-dependent privacy guarantees."""
+  if noise_multiplier == 0:
+    print('No differential privacy (additive noise is 0).')
+    return
+
+  print('In the conditions of Theorem 34 (https://arxiv.org/abs/1808.06651) '
+        'the training procedure results in the following privacy guarantees.')
+
+  print('Out of the total of {} samples:'.format(samples))
+
+  steps_per_epoch = samples // batch_size
+  orders = np.concatenate(
+      [np.linspace(2, 20, num=181),
+       np.linspace(20, 100, num=81)])
+  delta = 1e-5
+  for p in (.5, .9, .99):
+    steps = math.ceil(steps_per_epoch * p)  # Steps in the last epoch.
+    coef = 2 * (noise_multiplier * batch_size)**-2 * (
+        # Accounting for privacy loss
+        (epochs - 1) / steps_per_epoch +  # ... from all-but-last epochs
+        1 / (steps_per_epoch - steps + 1))  # ... due to the last epoch
+    # Use the RDP accountant to compute eps. Doing the computation
+    # analytically is an option.
+    rdp = [order * coef for order in orders]
+    eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
+    print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(
+        p * 100, eps, delta))
+
+  # Compute privacy guarantees for the Sampled Gaussian Mechanism.
+  rdp_sgm = compute_rdp(batch_size / samples, noise_multiplier,
+                        epochs * steps_per_epoch, orders)
+  eps_sgm, _, _ = get_privacy_spent(orders, rdp_sgm, target_delta=delta)
+  print('By comparison, DP-SGD analysis for training done with the same '
+        'parameters and random shuffling in each epoch guarantees '
+        '({:.2f}, {})-DP for all samples.'.format(eps_sgm, delta))
+
+
 def main(unused_argv):
   tf.logging.set_verbosity(tf.logging.INFO)
-  if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
-    raise ValueError('Number of microbatches should divide evenly batch_size')
   if FLAGS.data_l2_norm <= 0:
-    raise ValueError('FLAGS.data_l2_norm needs to be positive.')
-  if FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:
-    raise ValueError('The amplification by iteration analysis requires'
+    raise ValueError('data_l2_norm must be positive.')
+  if FLAGS.dpsgd and FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:
+    raise ValueError('The amplification-by-iteration analysis requires '
                      'learning_rate <= 2 / beta, where beta is the smoothness '
                      'of the loss function and is upper bounded by ||x||^2 / 4 '
                      'with ||x|| being the largest L2 norm of the samples.')
@@ -178,15 +207,12 @@ def main(unused_argv):
   train_data, train_labels, test_data, test_labels = load_mnist(
       data_l2_norm=FLAGS.data_l2_norm)
 
-  # Instantiate the tf.Estimator.
+  # Instantiate tf.Estimator.
   # pylint: disable=g-long-lambda
-  model_fn = lambda features, labels, mode: lr_model_fn(features, labels, mode,
-                                                        nclasses=10,
-                                                        dim=train_data.shape[1:]
-                                                       )
+  model_fn = lambda features, labels, mode: lr_model_fn(
+      features, labels, mode, nclasses=10, dim=train_data.shape[1:])
   mnist_classifier = tf.estimator.Estimator(
-      model_fn=model_fn,
-      model_dir=FLAGS.model_dir)
+      model_fn=model_fn, model_dir=FLAGS.model_dir)
 
   # Create tf.Estimator input functions for the training and test data.
   # To analyze the per-user privacy loss, we keep the same order of samples in
@@ -198,22 +224,27 @@ def main(unused_argv):
       num_epochs=FLAGS.epochs,
       shuffle=False)
   eval_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={'x': test_data},
-      y=test_labels,
-      num_epochs=1,
-      shuffle=False)
+      x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)
 
-  # Train the model
-  steps_per_epoch = train_data.shape[0] // FLAGS.batch_size
-  mnist_classifier.train(input_fn=train_input_fn,
-                         steps=steps_per_epoch * FLAGS.epochs)
+  # Train the model.
+  num_samples = train_data.shape[0]
+  steps_per_epoch = num_samples // FLAGS.batch_size
 
-  # Evaluate the model and print results
+  mnist_classifier.train(
+      input_fn=train_input_fn, steps=steps_per_epoch * FLAGS.epochs)
+
+  # Evaluate the model and print results.
   eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
-  test_accuracy = eval_results['accuracy']
-  print('Test accuracy after %d epochs is: %.3f' % (FLAGS.epochs,
-                                                    test_accuracy))
+  print('Test accuracy after {} epochs is: {:.2f}'.format(
+      FLAGS.epochs, eval_results['accuracy']))
 
+  if FLAGS.dpsgd:
+    print_privacy_guarantees(
+        epochs=FLAGS.epochs,
+        batch_size=FLAGS.batch_size,
+        samples=num_samples,
+        noise_multiplier=FLAGS.noise_multiplier,
+    )
 
 if __name__ == '__main__':
   app.run(main)
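
The accounting implemented in `print_privacy_guarantees` above reads directly off the code: at Rényi order alpha, the privacy loss of a given sample is alpha * coef, where coef combines a term (epochs - 1) / steps_per_epoch for the first epochs - 1 passes over the data with a term 1 / (steps_per_epoch - steps + 1) that shrinks the earlier in the final epoch the sample was last touched. The snippet below is a minimal standalone sketch of the same computation, assuming the `privacy.analysis.rdp_accountant` module from this repository is importable; the numbers plugged in (60000 MNIST training samples and the tutorial's default flags batch_size=5, noise_multiplier=0.05, epochs=5) are assumed values, not part of the patch.

    import math
    import numpy as np

    from privacy.analysis.rdp_accountant import get_privacy_spent

    samples, batch_size = 60000, 5           # assumed tutorial defaults
    epochs, noise_multiplier = 5, 0.05
    steps_per_epoch = samples // batch_size  # 12000 steps per epoch
    orders = np.concatenate(
        [np.linspace(2, 20, num=181), np.linspace(20, 100, num=81)])

    # Guarantee for the median sample: its last use lies halfway through
    # the final epoch.
    steps = int(math.ceil(steps_per_epoch * .5))
    coef = 2 * (noise_multiplier * batch_size)**-2 * (
        (epochs - 1) / steps_per_epoch + 1 / (steps_per_epoch - steps + 1))
    rdp = [order * coef for order in orders]  # RDP is linear in the order.
    eps, _, _ = get_privacy_spent(orders, rdp, target_delta=1e-5)
    print('50% of samples enjoy at least ({:.2f}, 1e-05)-DP'.format(eps))
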
From 6171474465cc07c13ce720a3637bf17ba787aea0 Mon Sep 17 00:00:00 2001
From: Nicolas Papernot
Date: Tue, 25 Jun 2019 17:05:41 -0700
Subject: [PATCH 02/10] harmonize analysis parameters with current DPSGD API

PiperOrigin-RevId: 255080643
---
 privacy/analysis/compute_dp_sgd_privacy.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/privacy/analysis/compute_dp_sgd_privacy.py b/privacy/analysis/compute_dp_sgd_privacy.py
index 4ca2ab1..e4bab4d 100644
--- a/privacy/analysis/compute_dp_sgd_privacy.py
+++ b/privacy/analysis/compute_dp_sgd_privacy.py
@@ -61,6 +61,10 @@ flags.mark_flag_as_required('epochs')
 
 def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
   """Compute and print results of DP-SGD analysis."""
+  # compute_rdp requires that sigma be the ratio of the standard deviation of
+  # the Gaussian noise to the l2-sensitivity of the function to which it is
+  # added. Hence, sigma here corresponds to the `noise_multiplier` parameter
+  # in the DP-SGD implementation found in privacy.optimizers.dp_optimizer.
   rdp = compute_rdp(q, sigma, steps, orders)
 
   eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
@@ -80,13 +84,10 @@ def main(argv):
   del argv  # argv is not used.
 
   q = FLAGS.batch_size / FLAGS.N  # q - the sampling ratio.
-
   if q > 1:
     raise app.UsageError('N must be larger than the batch size.')
-
   orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
             list(range(5, 64)) + [128, 256, 512])
-
   steps = int(math.ceil(FLAGS.epochs * FLAGS.N / FLAGS.batch_size))
 
   apply_dp_sgd_analysis(q, FLAGS.noise_multiplier, steps, orders, FLAGS.delta)
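
The comment added in this patch pins down the units of `sigma`: it is the ratio of the Gaussian noise's standard deviation to the l2-sensitivity of the summed gradients, i.e., exactly the `noise_multiplier` argument of the DP-SGD optimizers. A usage sketch of `apply_dp_sgd_analysis` follows; the training parameters below are illustrative assumptions, not values taken from the patch.

    from privacy.analysis.compute_dp_sgd_privacy import apply_dp_sgd_analysis

    # Hypothetical run: 60000 training examples, batch size 250,
    # noise multiplier 1.1, 60 epochs, delta of 1e-5.
    q = 250 / 60000                # per-step sampling ratio
    steps = int(60 * 60000 / 250)  # total number of DP-SGD steps
    orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
              list(range(5, 64)) + [128, 256, 512])
    apply_dp_sgd_analysis(q, sigma=1.1, steps=steps, orders=orders, delta=1e-5)
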
From 973a1759aa272868a759f6045a635d1173a5676f Mon Sep 17 00:00:00 2001
From: Galen Andrew
Date: Thu, 27 Jun 2019 14:37:30 -0700
Subject: [PATCH 03/10] Remove unused global_state reference from initial_sample_state.

global_state is never used in any of our existing DPQueries, and we don't
have any compelling use case.

PiperOrigin-RevId: 255480537
---
 privacy/analysis/privacy_ledger.py                   | 4 ++--
 privacy/dp_query/dp_query.py                         | 6 ++----
 privacy/dp_query/gaussian_query.py                   | 2 +-
 privacy/dp_query/gaussian_query_test.py              | 2 +-
 privacy/dp_query/nested_query.py                     | 4 ++--
 privacy/dp_query/no_privacy_query.py                 | 8 +++-----
 privacy/dp_query/normalized_query.py                 | 5 ++---
 privacy/dp_query/quantile_adaptive_clip_sum_query.py | 7 +++----
 privacy/dp_query/test_utils.py                       | 2 +-
 privacy/optimizers/dp_optimizer.py                   | 6 ++----
 10 files changed, 19 insertions(+), 27 deletions(-)

diff --git a/privacy/analysis/privacy_ledger.py b/privacy/analysis/privacy_ledger.py
index 9c29eb9..0416aa2 100644
--- a/privacy/analysis/privacy_ledger.py
+++ b/privacy/analysis/privacy_ledger.py
@@ -226,9 +226,9 @@ class QueryWithLedger(dp_query.DPQuery):
     """See base class."""
     return self._query.derive_sample_params(global_state)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    return self._query.initial_sample_state(global_state, template)
+    return self._query.initial_sample_state(template)
 
   def preprocess_record(self, params, record):
     """See base class."""
diff --git a/privacy/dp_query/dp_query.py b/privacy/dp_query/dp_query.py
index 4fa4fe3..e85d6c4 100644
--- a/privacy/dp_query/dp_query.py
+++ b/privacy/dp_query/dp_query.py
@@ -88,11 +88,10 @@ class DPQuery(object):
     return ()
 
   @abc.abstractmethod
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """Returns an initial state to use for the next sample.
 
     Args:
-      global_state: The current global state.
       template: A nested structure of tensors, TensorSpecs, or numpy arrays
         used as a template to create the initial sample state. It is assumed
         that the leaves of the structure are python scalars or some type that
        has
@@ -216,8 +215,7 @@ def zeros_like(arg):
 class SumAggregationDPQuery(DPQuery):
   """Base class for DPQueries that aggregate via sum."""
 
-  def initial_sample_state(self, global_state, template):
-    del global_state  # unused.
+  def initial_sample_state(self, template):
     return nest.map_structure(zeros_like, template)
 
   def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
diff --git a/privacy/dp_query/gaussian_query.py b/privacy/dp_query/gaussian_query.py
index 2977f91..3fc7be1 100644
--- a/privacy/dp_query/gaussian_query.py
+++ b/privacy/dp_query/gaussian_query.py
@@ -69,7 +69,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
   def derive_sample_params(self, global_state):
     return global_state.l2_norm_clip
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     return nest.map_structure(
         dp_query.zeros_like, template)
 
diff --git a/privacy/dp_query/gaussian_query_test.py b/privacy/dp_query/gaussian_query_test.py
index e2a1db0..913c3a8 100644
--- a/privacy/dp_query/gaussian_query_test.py
+++ b/privacy/dp_query/gaussian_query_test.py
@@ -99,7 +99,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
       query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
       global_state = query.initial_global_state()
       params = query.derive_sample_params(global_state)
-      sample_state = query.initial_sample_state(global_state, records[0])
+      sample_state = query.initial_sample_state(records[0])
       for record in records:
         sample_state = query.accumulate_record(params, sample_state, record)
       return sample_state
diff --git a/privacy/dp_query/nested_query.py b/privacy/dp_query/nested_query.py
index 62c1f5f..90efbf1 100644
--- a/privacy/dp_query/nested_query.py
+++ b/privacy/dp_query/nested_query.py
@@ -73,9 +73,9 @@ class NestedQuery(dp_query.DPQuery):
     """See base class."""
     return self._map_to_queries('derive_sample_params', global_state)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    return self._map_to_queries('initial_sample_state', global_state, template)
+    return self._map_to_queries('initial_sample_state', template)
 
   def preprocess_record(self, params, record):
     """See base class."""
diff --git a/privacy/dp_query/no_privacy_query.py b/privacy/dp_query/no_privacy_query.py
index 68731b4..6928f01 100644
--- a/privacy/dp_query/no_privacy_query.py
+++ b/privacy/dp_query/no_privacy_query.py
@@ -45,12 +45,10 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
   Accumulates vectors and normalizes by the total number of accumulated vectors.
   """
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    return (
-        super(NoPrivacyAverageQuery, self).initial_sample_state(
-            global_state, template),
-        tf.constant(0.0))
+    return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
+            tf.constant(0.0))
 
   def preprocess_record(self, params, record, weight=1):
     """Multiplies record by weight."""
diff --git a/privacy/dp_query/normalized_query.py b/privacy/dp_query/normalized_query.py
index 6e0d833..8f7dcc0 100644
--- a/privacy/dp_query/normalized_query.py
+++ b/privacy/dp_query/normalized_query.py
@@ -68,11 +68,10 @@ class NormalizedQuery(dp_query.DPQuery):
     """See base class."""
     return self._numerator.derive_sample_params(global_state.numerator_state)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
     # NormalizedQuery has no sample state beyond the numerator state.
-    return self._numerator.initial_sample_state(
-        global_state.numerator_state, template)
+    return self._numerator.initial_sample_state(template)
 
   def preprocess_record(self, params, record):
     return self._numerator.preprocess_record(params, record)
diff --git a/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/privacy/dp_query/quantile_adaptive_clip_sum_query.py
index eaa516b..8960c14 100644
--- a/privacy/dp_query/quantile_adaptive_clip_sum_query.py
+++ b/privacy/dp_query/quantile_adaptive_clip_sum_query.py
@@ -144,12 +144,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
         global_state.clipped_fraction_state)
     return self._SampleParams(sum_params, clipped_fraction_params)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    sum_state = self._sum_query.initial_sample_state(
-        global_state.sum_state, template)
+    sum_state = self._sum_query.initial_sample_state(template)
     clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
-        global_state.clipped_fraction_state, tf.constant(0.0))
+        tf.constant(0.0))
     return self._SampleState(sum_state, clipped_fraction_state)
 
   def preprocess_record(self, params, record):
diff --git a/privacy/dp_query/test_utils.py b/privacy/dp_query/test_utils.py
index f418b71..18456b3 100644
--- a/privacy/dp_query/test_utils.py
+++ b/privacy/dp_query/test_utils.py
@@ -38,7 +38,7 @@ def run_query(query, records, global_state=None, weights=None):
   if not global_state:
     global_state = query.initial_global_state()
   params = query.derive_sample_params(global_state)
-  sample_state = query.initial_sample_state(global_state, next(iter(records)))
+  sample_state = query.initial_sample_state(next(iter(records)))
   if weights is None:
     for record in records:
       sample_state = query.accumulate_record(params, sample_state, record)
diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py
index 59cfe13..e70086f 100644
--- a/privacy/optimizers/dp_optimizer.py
+++ b/privacy/optimizers/dp_optimizer.py
@@ -95,8 +95,7 @@ def make_optimizer_class(cls):
         self._num_microbatches = tf.shape(vector_loss)[0]
       if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
         self._dp_sum_query.set_batch_size(self._num_microbatches)
-      sample_state = self._dp_sum_query.initial_sample_state(
-          self._global_state, var_list)
+      sample_state = self._dp_sum_query.initial_sample_state(var_list)
       microbatches_losses = tf.reshape(vector_loss,
                                        [self._num_microbatches, -1])
       sample_params = (
@@ -162,8 +161,7 @@ def make_optimizer_class(cls):
             tf.trainable_variables() + tf.get_collection(
                 tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
 
-      sample_state = self._dp_sum_query.initial_sample_state(
-          self._global_state, var_list)
+      sample_state = self._dp_sum_query.initial_sample_state(var_list)
 
       if self._unroll_microbatches:
         for idx in range(self._num_microbatches):
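
For code outside this repository that implements or drives a `DPQuery`, this patch is a breaking signature change: `initial_sample_state` no longer accepts the (previously unused) `global_state` argument. A migration sketch against `GaussianSumQuery` follows; the template tensor is made up for illustration.

    import tensorflow as tf

    from privacy.dp_query import gaussian_query

    query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)

    template = tf.zeros([3])  # mirrors the structure of the records to sum
    # Before this patch:
    #   sample_state = query.initial_sample_state(global_state, template)
    # After this patch, the template alone determines the zeroed state:
    sample_state = query.initial_sample_state(template)
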
From bb7956ed7e8e819d482edef05a440caa063f9239 Mon Sep 17 00:00:00 2001
From: Nicolas Papernot
Date: Tue, 16 Jul 2019 13:58:56 -0700
Subject: [PATCH 04/10] fix keras typo

PiperOrigin-RevId: 258434656
---
 tutorials/mnist_dpsgd_tutorial_keras.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tutorials/mnist_dpsgd_tutorial_keras.py b/tutorials/mnist_dpsgd_tutorial_keras.py
index 71f67cb..865fb9f 100644
--- a/tutorials/mnist_dpsgd_tutorial_keras.py
+++ b/tutorials/mnist_dpsgd_tutorial_keras.py
@@ -41,10 +41,10 @@ flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
 flags.DEFINE_float('noise_multiplier', 1.1,
                    'Ratio of the standard deviation to the clipping norm')
 flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
-flags.DEFINE_integer('batch_size', 256, 'Batch size')
+flags.DEFINE_integer('batch_size', 250, 'Batch size')
 flags.DEFINE_integer('epochs', 60, 'Number of epochs')
 flags.DEFINE_integer(
-    'microbatches', 256, 'Number of microbatches '
+    'microbatches', 250, 'Number of microbatches '
     '(must evenly divide batch_size)')
 flags.DEFINE_string('model_dir', None, 'Model directory')
 
@@ -121,9 +121,8 @@ def main(unused_argv):
     optimizer = DPGradientDescentGaussianOptimizer(
         l2_norm_clip=FLAGS.l2_norm_clip,
         noise_multiplier=FLAGS.noise_multiplier,
-        num_microbatches=FLAGS.num_microbatches,
-        learning_rate=FLAGS.learning_rate,
-        unroll_microbatches=True)
+        num_microbatches=FLAGS.microbatches,
+        learning_rate=FLAGS.learning_rate)
     # Compute vector of per-example loss rather than its mean over a minibatch.
     loss = tf.keras.losses.CategoricalCrossentropy(
         from_logits=True, reduction=tf.losses.Reduction.NONE)

From 98723b9c3ad7b5b5fdeb4bc347e3bd79e54cbc1a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 17 Jul 2019 15:43:54 -0700
Subject: [PATCH 05/10] Added rdp_accountant dependency to privacy/BUILD.

PiperOrigin-RevId: 258657061
---
 privacy/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/privacy/BUILD b/privacy/BUILD
index dd7e102..6efaad7 100644
--- a/privacy/BUILD
+++ b/privacy/BUILD
@@ -9,6 +9,7 @@ py_library(
     srcs = ["__init__.py"],
    deps = [
         "//third_party/py/tensorflow_privacy/privacy/analysis:privacy_ledger",
+        "//third_party/py/tensorflow_privacy/privacy/analysis:rdp_accountant",
         "//third_party/py/tensorflow_privacy/privacy/dp_query",
         "//third_party/py/tensorflow_privacy/privacy/dp_query:gaussian_query",
         "//third_party/py/tensorflow_privacy/privacy/dp_query:nested_query",

From 3de1fcd829020d837ea0109c1c61ee42c1ad209d Mon Sep 17 00:00:00 2001
From: Nicolas Papernot
Date: Fri, 19 Jul 2019 13:40:45 -0700
Subject: [PATCH 06/10] add pylint to README

PiperOrigin-RevId: 259029555
---
 README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/README.md b/README.md
index 6fd80a3..5720801 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,10 @@ GitHub pull requests. To speed the code review process, we ask that:
   your pull requests. In most cases this can be done by running
   `autopep8 -i --indent-size 2 <file>` on the files you have edited.
 
+* You should also check your code with pylint and TensorFlow's pylint
+  [configuration file](https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc)
+  by running `pylint --rcfile=/path/to/the/tf/rcfile <edited file>`.
+
 * When making your first pull request, you
   [sign the Google CLA](https://cla.developers.google.com/clas)

From 3072c86c79733274823390cdd2b07fce731c2a07 Mon Sep 17 00:00:00 2001
From: jvmancuso
Date: Sat, 20 Jul 2019 10:43:48 -0400
Subject: [PATCH 07/10] find nest module based on TF version for quantile_adaptive_clip_sum_query.py

---
 privacy/dp_query/quantile_adaptive_clip_sum_query.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/privacy/dp_query/quantile_adaptive_clip_sum_query.py
index 8960c14..d8f0cbb 100644
--- a/privacy/dp_query/quantile_adaptive_clip_sum_query.py
+++ b/privacy/dp_query/quantile_adaptive_clip_sum_query.py
@@ -26,6 +26,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+from distutils.version import LooseVersion
 
 import tensorflow as tf
 
@@ -33,7 +34,10 @@ from privacy.dp_query import dp_query
 from privacy.dp_query import gaussian_query
 from privacy.dp_query import normalized_query
 
-nest = tf.contrib.framework.nest
+if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
+  nest = tf.contrib.framework.nest
+else:
+  nest = tf.nest

From 5cd2439401eebad2aed370f615842cc61a5230f6 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 24 Jul 2019 18:08:03 -0700
Subject: [PATCH 08/10] Remove calls to _dp_sum_query.set_batch_size in dp_optimizer.py, as no method with that name exists for objects of class QueryWithLedger.

PiperOrigin-RevId: 259858031
---
 privacy/optimizers/dp_optimizer.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py
index e70086f..83a3f4d 100644
--- a/privacy/optimizers/dp_optimizer.py
+++ b/privacy/optimizers/dp_optimizer.py
@@ -93,8 +93,6 @@ def make_optimizer_class(cls):
         vector_loss = loss()
       if self._num_microbatches is None:
         self._num_microbatches = tf.shape(vector_loss)[0]
-      if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
-        self._dp_sum_query.set_batch_size(self._num_microbatches)
       sample_state = self._dp_sum_query.initial_sample_state(var_list)
       microbatches_losses = tf.reshape(vector_loss,
                                        [self._num_microbatches, -1])
@@ -135,8 +133,6 @@ def make_optimizer_class(cls):
       # sampling from the dataset without replacement.
       if self._num_microbatches is None:
         self._num_microbatches = tf.shape(loss)[0]
-      if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
-        self._dp_sum_query.set_batch_size(self._num_microbatches)
       microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
       sample_params = (

From c14a5464402af2700a29771373a5d33770b23e0e Mon Sep 17 00:00:00 2001
From: Yilei Yang
Date: Mon, 29 Jul 2019 10:16:30 -0700
Subject: [PATCH 09/10] Explicitly mark Python binaries/tests with python_version = "PY2".

PiperOrigin-RevId: 260525846 --- privacy/dp_query/BUILD | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/privacy/dp_query/BUILD b/privacy/dp_query/BUILD index 2494096..32815fb 100644 --- a/privacy/dp_query/BUILD +++ b/privacy/dp_query/BUILD @@ -26,6 +26,7 @@ py_test( name = "gaussian_query_test", size = "small", srcs = ["gaussian_query_test.py"], + python_version = "PY2", deps = [ ":gaussian_query", ":test_utils", @@ -50,6 +51,7 @@ py_test( name = "no_privacy_query_test", size = "small", srcs = ["no_privacy_query_test.py"], + python_version = "PY2", deps = [ ":no_privacy_query", ":test_utils", @@ -72,6 +74,7 @@ py_test( name = "normalized_query_test", size = "small", srcs = ["normalized_query_test.py"], + python_version = "PY2", deps = [ ":gaussian_query", ":normalized_query", @@ -94,6 +97,7 @@ py_test( name = "nested_query_test", size = "small", srcs = ["nested_query_test.py"], + python_version = "PY2", deps = [ ":gaussian_query", ":nested_query", @@ -119,6 +123,7 @@ py_library( py_test( name = "quantile_adaptive_clip_sum_query_test", srcs = ["quantile_adaptive_clip_sum_query_test.py"], + python_version = "PY2", deps = [ ":quantile_adaptive_clip_sum_query", ":test_utils", From c08f3ebdc76d93679865888f93ca8b5224f70536 Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Thu, 1 Aug 2019 16:32:08 -0700 Subject: [PATCH 10/10] Workaround until the new `bolt_on` module is integrated into the rest of the TF Privacy build system. PiperOrigin-RevId: 261222062 --- privacy/__init__.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/privacy/__init__.py b/privacy/__init__.py index aab6e94..f530a11 100644 --- a/privacy/__init__.py +++ b/privacy/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. """TensorFlow Privacy library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import sys # pylint: disable=g-import-not-at-top @@ -42,8 +46,11 @@ else: from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer - from privacy.bolt_on.models import BoltOnModel - from privacy.bolt_on.optimizers import BoltOn - from privacy.bolt_on.losses import StrongConvexMixin - from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy - from privacy.bolt_on.losses import StrongConvexHuber + try: + from privacy.bolt_on.models import BoltOnModel + from privacy.bolt_on.optimizers import BoltOn + from privacy.bolt_on.losses import StrongConvexMixin + from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy + from privacy.bolt_on.losses import StrongConvexHuber + except ImportError: + print('module `bolt_on` was not found in this version of TF Privacy')
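
The `try`/`except ImportError` above keeps `import privacy` working in source trees where the `bolt_on` sources have not been integrated yet. Downstream code that wants the BoltOn symbols can probe the package rather than import them directly; a sketch (the attribute check is illustrative, not part of the patch):

    import privacy

    if hasattr(privacy, 'BoltOnModel'):
      print('BoltOn is available in this build.')
    else:
      print('bolt_on is missing; only the DP-SGD optimizers are exported.')
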