From 973a1759aa272868a759f6045a635d1173a5676f Mon Sep 17 00:00:00 2001 From: Galen Andrew Date: Thu, 27 Jun 2019 14:37:30 -0700 Subject: [PATCH] Remove unused global_state reference from initial_sample_state. global_state is never used in any of our existing DPQueries, and we don't have any compelling use case. PiperOrigin-RevId: 255480537 --- privacy/analysis/privacy_ledger.py | 4 ++-- privacy/dp_query/dp_query.py | 6 ++---- privacy/dp_query/gaussian_query.py | 2 +- privacy/dp_query/gaussian_query_test.py | 2 +- privacy/dp_query/nested_query.py | 4 ++-- privacy/dp_query/no_privacy_query.py | 8 +++----- privacy/dp_query/normalized_query.py | 5 ++--- privacy/dp_query/quantile_adaptive_clip_sum_query.py | 7 +++---- privacy/dp_query/test_utils.py | 2 +- privacy/optimizers/dp_optimizer.py | 6 ++---- 10 files changed, 19 insertions(+), 27 deletions(-) diff --git a/privacy/analysis/privacy_ledger.py b/privacy/analysis/privacy_ledger.py index 9c29eb9..0416aa2 100644 --- a/privacy/analysis/privacy_ledger.py +++ b/privacy/analysis/privacy_ledger.py @@ -226,9 +226,9 @@ class QueryWithLedger(dp_query.DPQuery): """See base class.""" return self._query.derive_sample_params(global_state) - def initial_sample_state(self, global_state, template): + def initial_sample_state(self, template): """See base class.""" - return self._query.initial_sample_state(global_state, template) + return self._query.initial_sample_state(template) def preprocess_record(self, params, record): """See base class.""" diff --git a/privacy/dp_query/dp_query.py b/privacy/dp_query/dp_query.py index 4fa4fe3..e85d6c4 100644 --- a/privacy/dp_query/dp_query.py +++ b/privacy/dp_query/dp_query.py @@ -88,11 +88,10 @@ class DPQuery(object): return () @abc.abstractmethod - def initial_sample_state(self, global_state, template): + def initial_sample_state(self, template): """Returns an initial state to use for the next sample. Args: - global_state: The current global state. template: A nested structure of tensors, TensorSpecs, or numpy arrays used as a template to create the initial sample state. It is assumed that the leaves of the structure are python scalars or some type that has @@ -216,8 +215,7 @@ def zeros_like(arg): class SumAggregationDPQuery(DPQuery): """Base class for DPQueries that aggregate via sum.""" - def initial_sample_state(self, global_state, template): - del global_state # unused. + def initial_sample_state(self, template): return nest.map_structure(zeros_like, template) def accumulate_preprocessed_record(self, sample_state, preprocessed_record): diff --git a/privacy/dp_query/gaussian_query.py b/privacy/dp_query/gaussian_query.py index 2977f91..3fc7be1 100644 --- a/privacy/dp_query/gaussian_query.py +++ b/privacy/dp_query/gaussian_query.py @@ -69,7 +69,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): def derive_sample_params(self, global_state): return global_state.l2_norm_clip - def initial_sample_state(self, global_state, template): + def initial_sample_state(self, template): return nest.map_structure( dp_query.zeros_like, template) diff --git a/privacy/dp_query/gaussian_query_test.py b/privacy/dp_query/gaussian_query_test.py index e2a1db0..913c3a8 100644 --- a/privacy/dp_query/gaussian_query_test.py +++ b/privacy/dp_query/gaussian_query_test.py @@ -99,7 +99,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase): query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - sample_state = query.initial_sample_state(global_state, records[0]) + sample_state = query.initial_sample_state(records[0]) for record in records: sample_state = query.accumulate_record(params, sample_state, record) return sample_state diff --git a/privacy/dp_query/nested_query.py b/privacy/dp_query/nested_query.py index 62c1f5f..90efbf1 100644 --- a/privacy/dp_query/nested_query.py +++ b/privacy/dp_query/nested_query.py @@ -73,9 +73,9 @@ class NestedQuery(dp_query.DPQuery): """See base class.""" return self._map_to_queries('derive_sample_params', global_state) - def initial_sample_state(self, global_state, template): + def initial_sample_state(self, template): """See base class.""" - return self._map_to_queries('initial_sample_state', global_state, template) + return self._map_to_queries('initial_sample_state', template) def preprocess_record(self, params, record): """See base class.""" diff --git a/privacy/dp_query/no_privacy_query.py b/privacy/dp_query/no_privacy_query.py index 68731b4..6928f01 100644 --- a/privacy/dp_query/no_privacy_query.py +++ b/privacy/dp_query/no_privacy_query.py @@ -45,12 +45,10 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery): Accumulates vectors and normalizes by the total number of accumulated vectors. """ - def initial_sample_state(self, global_state, template): + def initial_sample_state(self, template): """See base class.""" - return ( - super(NoPrivacyAverageQuery, self).initial_sample_state( - global_state, template), - tf.constant(0.0)) + return (super(NoPrivacyAverageQuery, self).initial_sample_state(template), + tf.constant(0.0)) def preprocess_record(self, params, record, weight=1): """Multiplies record by weight.""" diff --git a/privacy/dp_query/normalized_query.py b/privacy/dp_query/normalized_query.py index 6e0d833..8f7dcc0 100644 --- a/privacy/dp_query/normalized_query.py +++ b/privacy/dp_query/normalized_query.py @@ -68,11 +68,10 @@ class NormalizedQuery(dp_query.DPQuery): """See base class.""" return self._numerator.derive_sample_params(global_state.numerator_state) - def initial_sample_state(self, global_state, template): + def initial_sample_state(self, template): """See base class.""" # NormalizedQuery has no sample state beyond the numerator state. - return self._numerator.initial_sample_state( - global_state.numerator_state, template) + return self._numerator.initial_sample_state(template) def preprocess_record(self, params, record): return self._numerator.preprocess_record(params, record) diff --git a/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/privacy/dp_query/quantile_adaptive_clip_sum_query.py index eaa516b..8960c14 100644 --- a/privacy/dp_query/quantile_adaptive_clip_sum_query.py +++ b/privacy/dp_query/quantile_adaptive_clip_sum_query.py @@ -144,12 +144,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): global_state.clipped_fraction_state) return self._SampleParams(sum_params, clipped_fraction_params) - def initial_sample_state(self, global_state, template): + def initial_sample_state(self, template): """See base class.""" - sum_state = self._sum_query.initial_sample_state( - global_state.sum_state, template) + sum_state = self._sum_query.initial_sample_state(template) clipped_fraction_state = self._clipped_fraction_query.initial_sample_state( - global_state.clipped_fraction_state, tf.constant(0.0)) + tf.constant(0.0)) return self._SampleState(sum_state, clipped_fraction_state) def preprocess_record(self, params, record): diff --git a/privacy/dp_query/test_utils.py b/privacy/dp_query/test_utils.py index f418b71..18456b3 100644 --- a/privacy/dp_query/test_utils.py +++ b/privacy/dp_query/test_utils.py @@ -38,7 +38,7 @@ def run_query(query, records, global_state=None, weights=None): if not global_state: global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - sample_state = query.initial_sample_state(global_state, next(iter(records))) + sample_state = query.initial_sample_state(next(iter(records))) if weights is None: for record in records: sample_state = query.accumulate_record(params, sample_state, record) diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 59cfe13..e70086f 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -95,8 +95,7 @@ def make_optimizer_class(cls): self._num_microbatches = tf.shape(vector_loss)[0] if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger): self._dp_sum_query.set_batch_size(self._num_microbatches) - sample_state = self._dp_sum_query.initial_sample_state( - self._global_state, var_list) + sample_state = self._dp_sum_query.initial_sample_state(var_list) microbatches_losses = tf.reshape(vector_loss, [self._num_microbatches, -1]) sample_params = ( @@ -162,8 +161,7 @@ def make_optimizer_class(cls): tf.trainable_variables() + tf.get_collection( tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) - sample_state = self._dp_sum_query.initial_sample_state( - self._global_state, var_list) + sample_state = self._dp_sum_query.initial_sample_state(var_list) if self._unroll_microbatches: for idx in range(self._num_microbatches):