Remove unused global_state reference from initial_sample_state.

global_state is never used in any of our existing DPQueries, and we don't have a compelling use case for it.

PiperOrigin-RevId: 255480537
Galen Andrew 2019-06-27 14:37:30 -07:00 committed by A. Unique TensorFlower
parent 6171474465
commit 973a1759aa
10 changed files with 19 additions and 27 deletions
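Not part of the diff: a minimal sketch of the call-site change, using the same GaussianSumQuery constructor that appears in the test hunk below. The import path and the record template are illustrative assumptions, not taken from this commit.

    import tensorflow as tf
    from privacy.dp_query import gaussian_query  # assumed import path

    query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    record_template = tf.constant([0.0, 0.0])  # illustrative record shape

    # Before this commit the (unused) global state had to be threaded through:
    #   sample_state = query.initial_sample_state(global_state, record_template)
    # After this commit only the template is needed:
    sample_state = query.initial_sample_state(record_template)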


@@ -226,9 +226,9 @@ class QueryWithLedger(dp_query.DPQuery):
     """See base class."""
     return self._query.derive_sample_params(global_state)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    return self._query.initial_sample_state(global_state, template)
+    return self._query.initial_sample_state(template)
 
   def preprocess_record(self, params, record):
     """See base class."""


@@ -88,11 +88,10 @@ class DPQuery(object):
     return ()
 
   @abc.abstractmethod
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """Returns an initial state to use for the next sample.
 
     Args:
-      global_state: The current global state.
       template: A nested structure of tensors, TensorSpecs, or numpy arrays used
         as a template to create the initial sample state. It is assumed that the
         leaves of the structure are python scalars or some type that has
@@ -216,8 +215,7 @@ def zeros_like(arg):
 class SumAggregationDPQuery(DPQuery):
   """Base class for DPQueries that aggregate via sum."""
 
-  def initial_sample_state(self, global_state, template):
-    del global_state  # unused.
+  def initial_sample_state(self, template):
     return nest.map_structure(zeros_like, template)
 
   def accumulate_preprocessed_record(self, sample_state, preprocessed_record):

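Not part of the diff: a rough sketch of what the sum-aggregating base class now does with the template argument, per the SumAggregationDPQuery hunk above. tf.nest and tf.zeros_like stand in for the module's own nest import and zeros_like helper.

    import tensorflow as tf

    # Illustrative nested template; per the docstring above it may hold tensors,
    # TensorSpecs, or numpy arrays.
    template = {'weights': tf.ones([2, 3]), 'bias': tf.ones([3])}

    # The initial sample state is simply zeros shaped like the template, mapped
    # over the nested structure.
    sample_state = tf.nest.map_structure(tf.zeros_like, template)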

@@ -69,7 +69,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
   def derive_sample_params(self, global_state):
     return global_state.l2_norm_clip
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     return nest.map_structure(
         dp_query.zeros_like, template)


@@ -99,7 +99,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
       query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
       global_state = query.initial_global_state()
       params = query.derive_sample_params(global_state)
-      sample_state = query.initial_sample_state(global_state, records[0])
+      sample_state = query.initial_sample_state(records[0])
       for record in records:
         sample_state = query.accumulate_record(params, sample_state, record)
       return sample_state


@@ -73,9 +73,9 @@ class NestedQuery(dp_query.DPQuery):
     """See base class."""
     return self._map_to_queries('derive_sample_params', global_state)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    return self._map_to_queries('initial_sample_state', global_state, template)
+    return self._map_to_queries('initial_sample_state', template)
 
   def preprocess_record(self, params, record):
     """See base class."""


@@ -45,11 +45,9 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
   Accumulates vectors and normalizes by the total number of accumulated vectors.
   """
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    return (
-        super(NoPrivacyAverageQuery, self).initial_sample_state(
-            global_state, template),
+    return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
             tf.constant(0.0))
 
   def preprocess_record(self, params, record, weight=1):

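Not part of the diff: after this change NoPrivacyAverageQuery seeds its sample state with a pair of (zeroed sum, scalar count), the count being what the accumulated vectors are later normalized by. A rough standalone sketch:

    import tensorflow as tf

    template = tf.constant([0.0, 0.0])  # illustrative record template

    # Mirrors the return value in the hunk above: a zeroed sum plus a 0.0 count.
    sample_state = (tf.zeros_like(template), tf.constant(0.0))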

@@ -68,11 +68,10 @@ class NormalizedQuery(dp_query.DPQuery):
     """See base class."""
     return self._numerator.derive_sample_params(global_state.numerator_state)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
     # NormalizedQuery has no sample state beyond the numerator state.
-    return self._numerator.initial_sample_state(
-        global_state.numerator_state, template)
+    return self._numerator.initial_sample_state(template)
 
   def preprocess_record(self, params, record):
     return self._numerator.preprocess_record(params, record)


@@ -144,12 +144,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
         global_state.clipped_fraction_state)
     return self._SampleParams(sum_params, clipped_fraction_params)
 
-  def initial_sample_state(self, global_state, template):
+  def initial_sample_state(self, template):
     """See base class."""
-    sum_state = self._sum_query.initial_sample_state(
-        global_state.sum_state, template)
+    sum_state = self._sum_query.initial_sample_state(template)
     clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
-        global_state.clipped_fraction_state, tf.constant(0.0))
+        tf.constant(0.0))
     return self._SampleState(sum_state, clipped_fraction_state)
 
   def preprocess_record(self, params, record):


@@ -38,7 +38,7 @@ def run_query(query, records, global_state=None, weights=None):
   if not global_state:
     global_state = query.initial_global_state()
   params = query.derive_sample_params(global_state)
-  sample_state = query.initial_sample_state(global_state, next(iter(records)))
+  sample_state = query.initial_sample_state(next(iter(records)))
   if weights is None:
     for record in records:
       sample_state = query.accumulate_record(params, sample_state, record)


@@ -95,8 +95,7 @@ def make_optimizer_class(cls):
         self._num_microbatches = tf.shape(vector_loss)[0]
         if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
           self._dp_sum_query.set_batch_size(self._num_microbatches)
-        sample_state = self._dp_sum_query.initial_sample_state(
-            self._global_state, var_list)
+        sample_state = self._dp_sum_query.initial_sample_state(var_list)
         microbatches_losses = tf.reshape(vector_loss,
                                          [self._num_microbatches, -1])
         sample_params = (
@@ -162,8 +161,7 @@ def make_optimizer_class(cls):
             tf.trainable_variables() + tf.get_collection(
                 tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
 
-        sample_state = self._dp_sum_query.initial_sample_state(
-            self._global_state, var_list)
+        sample_state = self._dp_sum_query.initial_sample_state(var_list)
 
         if self._unroll_microbatches:
           for idx in range(self._num_microbatches):