Remove unused global_state reference from initial_sample_state.

global_state is never used in any of our existing DPQueries, and we don't have any compelling use case.

PiperOrigin-RevId: 255480537
This commit is contained in:
Galen Andrew 2019-06-27 14:37:30 -07:00 committed by A. Unique TensorFlower
parent 6171474465
commit 973a1759aa
10 changed files with 19 additions and 27 deletions

View file

@ -226,9 +226,9 @@ class QueryWithLedger(dp_query.DPQuery):
"""See base class.""" """See base class."""
return self._query.derive_sample_params(global_state) return self._query.derive_sample_params(global_state)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
return self._query.initial_sample_state(global_state, template) return self._query.initial_sample_state(template)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):
"""See base class.""" """See base class."""

View file

@ -88,11 +88,10 @@ class DPQuery(object):
return () return ()
@abc.abstractmethod @abc.abstractmethod
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""Returns an initial state to use for the next sample. """Returns an initial state to use for the next sample.
Args: Args:
global_state: The current global state.
template: A nested structure of tensors, TensorSpecs, or numpy arrays used template: A nested structure of tensors, TensorSpecs, or numpy arrays used
as a template to create the initial sample state. It is assumed that the as a template to create the initial sample state. It is assumed that the
leaves of the structure are python scalars or some type that has leaves of the structure are python scalars or some type that has
@ -216,8 +215,7 @@ def zeros_like(arg):
class SumAggregationDPQuery(DPQuery): class SumAggregationDPQuery(DPQuery):
"""Base class for DPQueries that aggregate via sum.""" """Base class for DPQueries that aggregate via sum."""
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
del global_state # unused.
return nest.map_structure(zeros_like, template) return nest.map_structure(zeros_like, template)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record): def accumulate_preprocessed_record(self, sample_state, preprocessed_record):

View file

@ -69,7 +69,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def derive_sample_params(self, global_state): def derive_sample_params(self, global_state):
return global_state.l2_norm_clip return global_state.l2_norm_clip
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
return nest.map_structure( return nest.map_structure(
dp_query.zeros_like, template) dp_query.zeros_like, template)

View file

@ -99,7 +99,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0) query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(global_state, records[0]) sample_state = query.initial_sample_state(records[0])
for record in records: for record in records:
sample_state = query.accumulate_record(params, sample_state, record) sample_state = query.accumulate_record(params, sample_state, record)
return sample_state return sample_state

View file

@ -73,9 +73,9 @@ class NestedQuery(dp_query.DPQuery):
"""See base class.""" """See base class."""
return self._map_to_queries('derive_sample_params', global_state) return self._map_to_queries('derive_sample_params', global_state)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
return self._map_to_queries('initial_sample_state', global_state, template) return self._map_to_queries('initial_sample_state', template)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):
"""See base class.""" """See base class."""

View file

@ -45,11 +45,9 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
Accumulates vectors and normalizes by the total number of accumulated vectors. Accumulates vectors and normalizes by the total number of accumulated vectors.
""" """
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
return ( return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
super(NoPrivacyAverageQuery, self).initial_sample_state(
global_state, template),
tf.constant(0.0)) tf.constant(0.0))
def preprocess_record(self, params, record, weight=1): def preprocess_record(self, params, record, weight=1):

View file

@ -68,11 +68,10 @@ class NormalizedQuery(dp_query.DPQuery):
"""See base class.""" """See base class."""
return self._numerator.derive_sample_params(global_state.numerator_state) return self._numerator.derive_sample_params(global_state.numerator_state)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
# NormalizedQuery has no sample state beyond the numerator state. # NormalizedQuery has no sample state beyond the numerator state.
return self._numerator.initial_sample_state( return self._numerator.initial_sample_state(template)
global_state.numerator_state, template)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):
return self._numerator.preprocess_record(params, record) return self._numerator.preprocess_record(params, record)

View file

@ -144,12 +144,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
global_state.clipped_fraction_state) global_state.clipped_fraction_state)
return self._SampleParams(sum_params, clipped_fraction_params) return self._SampleParams(sum_params, clipped_fraction_params)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
sum_state = self._sum_query.initial_sample_state( sum_state = self._sum_query.initial_sample_state(template)
global_state.sum_state, template)
clipped_fraction_state = self._clipped_fraction_query.initial_sample_state( clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
global_state.clipped_fraction_state, tf.constant(0.0)) tf.constant(0.0))
return self._SampleState(sum_state, clipped_fraction_state) return self._SampleState(sum_state, clipped_fraction_state)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):

View file

@ -38,7 +38,7 @@ def run_query(query, records, global_state=None, weights=None):
if not global_state: if not global_state:
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(global_state, next(iter(records))) sample_state = query.initial_sample_state(next(iter(records)))
if weights is None: if weights is None:
for record in records: for record in records:
sample_state = query.accumulate_record(params, sample_state, record) sample_state = query.accumulate_record(params, sample_state, record)

View file

@ -95,8 +95,7 @@ def make_optimizer_class(cls):
self._num_microbatches = tf.shape(vector_loss)[0] self._num_microbatches = tf.shape(vector_loss)[0]
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger): if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
self._dp_sum_query.set_batch_size(self._num_microbatches) self._dp_sum_query.set_batch_size(self._num_microbatches)
sample_state = self._dp_sum_query.initial_sample_state( sample_state = self._dp_sum_query.initial_sample_state(var_list)
self._global_state, var_list)
microbatches_losses = tf.reshape(vector_loss, microbatches_losses = tf.reshape(vector_loss,
[self._num_microbatches, -1]) [self._num_microbatches, -1])
sample_params = ( sample_params = (
@ -162,8 +161,7 @@ def make_optimizer_class(cls):
tf.trainable_variables() + tf.get_collection( tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._dp_sum_query.initial_sample_state( sample_state = self._dp_sum_query.initial_sample_state(var_list)
self._global_state, var_list)
if self._unroll_microbatches: if self._unroll_microbatches:
for idx in range(self._num_microbatches): for idx in range(self._num_microbatches):