From 3908429796a3d7fb7615f45592bb1a1fbf3f0e1a Mon Sep 17 00:00:00 2001 From: Galen Andrew Date: Wed, 15 May 2019 16:06:15 -0700 Subject: [PATCH] Make DPQuery classes (almost) completely functional: the only state from the initializer that is used gets pushed into the initial_global_state. PiperOrigin-RevId: 248424593 --- privacy/analysis/privacy_ledger.py | 4 +- privacy/dp_query/gaussian_query.py | 27 +++++-- privacy/dp_query/normalized_query.py | 35 ++++++--- .../quantile_adaptive_clip_sum_query.py | 75 ++++++++++++------- .../quantile_adaptive_clip_sum_query_test.py | 4 +- privacy/optimizers/dp_optimizer.py | 8 +- 6 files changed, 100 insertions(+), 53 deletions(-) diff --git a/privacy/analysis/privacy_ledger.py b/privacy/analysis/privacy_ledger.py index f6394d0..448f4cc 100644 --- a/privacy/analysis/privacy_ledger.py +++ b/privacy/analysis/privacy_ledger.py @@ -257,6 +257,6 @@ class QueryWithLedger(dp_query.DPQuery): with tf.control_dependencies([self._ledger.finalize_sample()]): return self._query.get_noised_result(sample_state, global_state) - def set_denominator(self, num_microbatches, microbatch_size=1): - self._query.set_denominator(num_microbatches) + def set_denominator(self, global_state, num_microbatches, microbatch_size=1): self._ledger.set_sample_size(num_microbatches * microbatch_size) + return self._query.set_denominator(global_state, num_microbatches) diff --git a/privacy/dp_query/gaussian_query.py b/privacy/dp_query/gaussian_query.py index 07817f0..35e0bcf 100644 --- a/privacy/dp_query/gaussian_query.py +++ b/privacy/dp_query/gaussian_query.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from distutils.version import LooseVersion import tensorflow as tf @@ -37,6 +39,10 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): Accumulates clipped vectors, then adds Gaussian noise to the sum. """ + # pylint: disable=invalid-name + _GlobalState = collections.namedtuple( + '_GlobalState', ['l2_norm_clip', 'stddev']) + def __init__(self, l2_norm_clip, stddev, ledger=None): """Initializes the GaussianSumQuery. @@ -46,17 +52,26 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): stddev: The stddev of the noise added to the sum. ledger: The privacy ledger to which queries should be recorded. """ - self._l2_norm_clip = tf.cast(l2_norm_clip, tf.float32) - self._stddev = tf.cast(stddev, tf.float32) + self._l2_norm_clip = l2_norm_clip + self._stddev = stddev self._ledger = ledger + def make_global_state(self, l2_norm_clip, stddev): + """Creates a global state from the given parameters.""" + return self._GlobalState(tf.cast(l2_norm_clip, tf.float32), + tf.cast(stddev, tf.float32)) + + def initial_global_state(self): + return self.make_global_state(self._l2_norm_clip, self._stddev) + def derive_sample_params(self, global_state): - return self._l2_norm_clip + return global_state.l2_norm_clip def initial_sample_state(self, global_state, template): if self._ledger: dependencies = [ - self._ledger.record_sum_query(self._l2_norm_clip, self._stddev) + self._ledger.record_sum_query( + global_state.l2_norm_clip, global_state.stddev) ] else: dependencies = [] @@ -89,9 +104,9 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): """See base class.""" if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): def add_noise(v): - return v + tf.random_normal(tf.shape(v), stddev=self._stddev) + return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev) else: - random_normal = tf.random_normal_initializer(stddev=self._stddev) + random_normal = tf.random_normal_initializer(stddev=global_state.stddev) def add_noise(v): return v + random_normal(tf.shape(v)) diff --git a/privacy/dp_query/normalized_query.py b/privacy/dp_query/normalized_query.py index 0cc73c4..c3ca4d4 100644 --- a/privacy/dp_query/normalized_query.py +++ b/privacy/dp_query/normalized_query.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from distutils.version import LooseVersion import tensorflow as tf @@ -33,6 +35,10 @@ else: class NormalizedQuery(dp_query.DPQuery): """DPQuery for queries with a DPQuery numerator and fixed denominator.""" + # pylint: disable=invalid-name + _GlobalState = collections.namedtuple( + '_GlobalState', ['numerator_state', 'denominator']) + def __init__(self, numerator_query, denominator): """Initializer for NormalizedQuery. @@ -43,22 +49,26 @@ class NormalizedQuery(dp_query.DPQuery): called. """ self._numerator = numerator_query - self._denominator = ( - tf.cast(denominator, tf.float32) if denominator is not None else None) + self._denominator = denominator def initial_global_state(self): """See base class.""" - # NormalizedQuery has no global state beyond the numerator state. - return self._numerator.initial_global_state() + if self._denominator is not None: + denominator = tf.cast(self._denominator, tf.float32) + else: + denominator = None + return self._GlobalState( + self._numerator.initial_global_state(), denominator) def derive_sample_params(self, global_state): """See base class.""" - return self._numerator.derive_sample_params(global_state) + return self._numerator.derive_sample_params(global_state.numerator_state) def initial_sample_state(self, global_state, template): """See base class.""" # NormalizedQuery has no sample state beyond the numerator state. - return self._numerator.initial_sample_state(global_state, template) + return self._numerator.initial_sample_state( + global_state.numerator_state, template) def preprocess_record(self, params, record): return self._numerator.preprocess_record(params, record) @@ -72,16 +82,17 @@ class NormalizedQuery(dp_query.DPQuery): def get_noised_result(self, sample_state, global_state): """See base class.""" noised_sum, new_sum_global_state = self._numerator.get_noised_result( - sample_state, global_state) + sample_state, global_state.numerator_state) def normalize(v): - return tf.truediv(v, self._denominator) + return tf.truediv(v, global_state.denominator) - return nest.map_structure(normalize, noised_sum), new_sum_global_state + return (nest.map_structure(normalize, noised_sum), + self._GlobalState(new_sum_global_state, global_state.denominator)) def merge_sample_states(self, sample_state_1, sample_state_2): """See base class.""" return self._numerator.merge_sample_states(sample_state_1, sample_state_2) - def set_denominator(self, denominator): - """Sets the denominator for the NormalizedQuery.""" - self._denominator = tf.cast(denominator, tf.float32) + def set_denominator(self, global_state, denominator): + """Returns an updated global_state with the given denominator.""" + return global_state._replace(denominator=tf.cast(denominator, tf.float32)) diff --git a/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/privacy/dp_query/quantile_adaptive_clip_sum_query.py index a8eb400..6aa9785 100644 --- a/privacy/dp_query/quantile_adaptive_clip_sum_query.py +++ b/privacy/dp_query/quantile_adaptive_clip_sum_query.py @@ -45,7 +45,13 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): # pylint: disable=invalid-name _GlobalState = collections.namedtuple( - '_GlobalState', ['l2_norm_clip', 'sum_state', 'clipped_fraction_state']) + '_GlobalState', [ + 'l2_norm_clip', + 'noise_multiplier', + 'target_unclipped_quantile', + 'learning_rate', + 'sum_state', + 'clipped_fraction_state']) # pylint: disable=invalid-name _SampleState = collections.namedtuple( @@ -75,8 +81,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): found for which approximately 20% of updates are clipped each round. learning_rate: The learning rate for the clipping norm adaptation. A rate of r means that the clipping norm will change by a maximum of r at - each step. This maximum is attained when |clip - target| is 1.0. Can be - a tf.Variable for example to implement a learning rate schedule. + each step. This maximum is attained when |clip - target| is 1.0. clipped_count_stddev: The stddev of the noise added to the clipped_count. Since the sensitivity of the clipped count is 0.5, as a rule of thumb it should be about 0.5 for reasonable privacy. @@ -84,19 +89,14 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): estimate the clipped count quantile. ledger: The privacy ledger to which queries should be recorded. """ - self._initial_l2_norm_clip = tf.cast(initial_l2_norm_clip, tf.float32) - self._noise_multiplier = tf.cast(noise_multiplier, tf.float32) - self._target_unclipped_quantile = tf.cast( - target_unclipped_quantile, tf.float32) - self._learning_rate = tf.cast(learning_rate, tf.float32) + self._initial_l2_norm_clip = initial_l2_norm_clip + self._noise_multiplier = noise_multiplier + self._target_unclipped_quantile = target_unclipped_quantile + self._learning_rate = learning_rate - self._l2_norm_clip = tf.Variable(self._initial_l2_norm_clip) - self._sum_stddev = tf.Variable( - self._initial_l2_norm_clip * self._noise_multiplier) + # Initialize sum query's global state with None, to be set later. self._sum_query = gaussian_query.GaussianSumQuery( - self._l2_norm_clip, - self._sum_stddev, - ledger) + None, None, ledger) # self._clipped_fraction_query is a DPQuery used to estimate the fraction of # records that are clipped. It accumulates an indicator 0/1 of whether each @@ -115,29 +115,40 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): def initial_global_state(self): """See base class.""" + initial_l2_norm_clip = tf.cast(self._initial_l2_norm_clip, tf.float32) + noise_multiplier = tf.cast(self._noise_multiplier, tf.float32) + target_unclipped_quantile = tf.cast(self._target_unclipped_quantile, + tf.float32) + learning_rate = tf.cast(self._learning_rate, tf.float32) + sum_stddev = initial_l2_norm_clip * noise_multiplier + + sum_query_global_state = self._sum_query.make_global_state( + l2_norm_clip=initial_l2_norm_clip, + stddev=sum_stddev) + return self._GlobalState( - self._initial_l2_norm_clip, - self._sum_query.initial_global_state(), + initial_l2_norm_clip, + noise_multiplier, + target_unclipped_quantile, + learning_rate, + sum_query_global_state, self._clipped_fraction_query.initial_global_state()) def derive_sample_params(self, global_state): """See base class.""" - gs = global_state # Assign values to variables that inner sum query uses. - tf.assign(self._l2_norm_clip, gs.l2_norm_clip) - tf.assign(self._sum_stddev, gs.l2_norm_clip * self._noise_multiplier) - sum_params = self._sum_query.derive_sample_params(gs.sum_state) + sum_params = self._sum_query.derive_sample_params(global_state.sum_state) clipped_fraction_params = self._clipped_fraction_query.derive_sample_params( - gs.clipped_fraction_state) + global_state.clipped_fraction_state) return self._SampleParams(sum_params, clipped_fraction_params) def initial_sample_state(self, global_state, template): """See base class.""" - clipped_fraction_state = self._clipped_fraction_query.initial_sample_state( - global_state.clipped_fraction_state, tf.constant(0.0)) sum_state = self._sum_query.initial_sample_state( global_state.sum_state, template) + clipped_fraction_state = self._clipped_fraction_query.initial_sample_state( + global_state.clipped_fraction_state, tf.constant(0.0)) return self._SampleState(sum_state, clipped_fraction_state) def preprocess_record(self, params, record): @@ -187,6 +198,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): noised_vectors, sum_state = self._sum_query.get_noised_result( sample_state.sum_state, gs.sum_state) + del sum_state # Unused. To be set explicitly later. clipped_fraction_result, new_clipped_fraction_state = ( self._clipped_fraction_query.get_noised_result( @@ -202,15 +214,20 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): # Loss function is convex, with derivative in [-1, 1], and minimized when # the true quantile matches the target. - loss_grad = unclipped_quantile - self._target_unclipped_quantile + loss_grad = unclipped_quantile - global_state.target_unclipped_quantile - new_l2_norm_clip = gs.l2_norm_clip - self._learning_rate * loss_grad + new_l2_norm_clip = gs.l2_norm_clip - global_state.learning_rate * loss_grad new_l2_norm_clip = tf.maximum(0.0, new_l2_norm_clip) - new_global_state = self._GlobalState( - new_l2_norm_clip, - sum_state, - new_clipped_fraction_state) + new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier + new_sum_query_global_state = self._sum_query.make_global_state( + l2_norm_clip=new_l2_norm_clip, + stddev=new_sum_stddev) + + new_global_state = global_state._replace( + l2_norm_clip=new_l2_norm_clip, + sum_state=new_sum_query_global_state, + clipped_fraction_state=new_clipped_fraction_state) return noised_vectors, new_global_state diff --git a/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py index f24c9c0..396d240 100644 --- a/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py +++ b/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py @@ -270,7 +270,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase): tf.assign(selection_probability, 0.1) _, global_state = test_utils.run_query(query, [record1, record2]) - expected_queries = [[0.5, 0.0], [10.0, 10.0]] + expected_queries = [[10.0, 10.0], [0.5, 0.0]] formatted = ledger.get_formatted_ledger_eager() sample_1 = formatted[0] self.assertAllClose(sample_1.population_size, 10.0) @@ -288,7 +288,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase): self.assertAllClose(sample_1.selection_probability, 0.1) self.assertAllClose(sample_1.queries, expected_queries) - expected_queries_2 = [[0.5, 0.0], [9.0, 9.0]] + expected_queries_2 = [[9.0, 9.0], [0.5, 0.0]] self.assertAllClose(sample_2.population_size, 20.0) self.assertAllClose(sample_2.selection_probability, 0.2) self.assertAllClose(sample_2.queries, expected_queries_2) diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 2d0ab0f..2b27191 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -88,7 +88,9 @@ def make_optimizer_class(cls): vector_loss = loss() if self._num_microbatches is None: self._num_microbatches = tf.shape(vector_loss)[0] - self._dp_average_query.set_denominator(self._num_microbatches) + self._global_state = self._dp_average_query.set_denominator( + self._global_state, + self._num_microbatches) sample_state = self._dp_average_query.initial_sample_state( self._global_state, var_list) microbatches_losses = tf.reshape(vector_loss, @@ -126,7 +128,9 @@ def make_optimizer_class(cls): # sampling from the dataset without replacement. if self._num_microbatches is None: self._num_microbatches = tf.shape(loss)[0] - self._dp_average_query.set_denominator(self._num_microbatches) + self._global_state = self._dp_average_query.set_denominator( + self._global_state, + self._num_microbatches) microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1]) sample_params = ( self._dp_average_query.derive_sample_params(self._global_state))