Remove unused global_state reference from initial_sample_state.
global_state is never used in any of our existing DPQueries, and we don't have any compelling use case. PiperOrigin-RevId: 255480537
This commit is contained in:
parent
6171474465
commit
973a1759aa
10 changed files with 19 additions and 27 deletions
|
@ -226,9 +226,9 @@ class QueryWithLedger(dp_query.DPQuery):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._query.derive_sample_params(global_state)
|
return self._query.derive_sample_params(global_state)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._query.initial_sample_state(global_state, template)
|
return self._query.initial_sample_state(template)
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
|
|
@ -88,11 +88,10 @@ class DPQuery(object):
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
"""Returns an initial state to use for the next sample.
|
"""Returns an initial state to use for the next sample.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
global_state: The current global state.
|
|
||||||
template: A nested structure of tensors, TensorSpecs, or numpy arrays used
|
template: A nested structure of tensors, TensorSpecs, or numpy arrays used
|
||||||
as a template to create the initial sample state. It is assumed that the
|
as a template to create the initial sample state. It is assumed that the
|
||||||
leaves of the structure are python scalars or some type that has
|
leaves of the structure are python scalars or some type that has
|
||||||
|
@ -216,8 +215,7 @@ def zeros_like(arg):
|
||||||
class SumAggregationDPQuery(DPQuery):
|
class SumAggregationDPQuery(DPQuery):
|
||||||
"""Base class for DPQueries that aggregate via sum."""
|
"""Base class for DPQueries that aggregate via sum."""
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
del global_state # unused.
|
|
||||||
return nest.map_structure(zeros_like, template)
|
return nest.map_structure(zeros_like, template)
|
||||||
|
|
||||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||||
|
|
|
@ -69,7 +69,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
return global_state.l2_norm_clip
|
return global_state.l2_norm_clip
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
dp_query.zeros_like, template)
|
dp_query.zeros_like, template)
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
sample_state = query.initial_sample_state(global_state, records[0])
|
sample_state = query.initial_sample_state(records[0])
|
||||||
for record in records:
|
for record in records:
|
||||||
sample_state = query.accumulate_record(params, sample_state, record)
|
sample_state = query.accumulate_record(params, sample_state, record)
|
||||||
return sample_state
|
return sample_state
|
||||||
|
|
|
@ -73,9 +73,9 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._map_to_queries('derive_sample_params', global_state)
|
return self._map_to_queries('derive_sample_params', global_state)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._map_to_queries('initial_sample_state', global_state, template)
|
return self._map_to_queries('initial_sample_state', template)
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
|
|
@ -45,12 +45,10 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||||
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return (
|
return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
|
||||||
super(NoPrivacyAverageQuery, self).initial_sample_state(
|
tf.constant(0.0))
|
||||||
global_state, template),
|
|
||||||
tf.constant(0.0))
|
|
||||||
|
|
||||||
def preprocess_record(self, params, record, weight=1):
|
def preprocess_record(self, params, record, weight=1):
|
||||||
"""Multiplies record by weight."""
|
"""Multiplies record by weight."""
|
||||||
|
|
|
@ -68,11 +68,10 @@ class NormalizedQuery(dp_query.DPQuery):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._numerator.derive_sample_params(global_state.numerator_state)
|
return self._numerator.derive_sample_params(global_state.numerator_state)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
# NormalizedQuery has no sample state beyond the numerator state.
|
# NormalizedQuery has no sample state beyond the numerator state.
|
||||||
return self._numerator.initial_sample_state(
|
return self._numerator.initial_sample_state(template)
|
||||||
global_state.numerator_state, template)
|
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
return self._numerator.preprocess_record(params, record)
|
return self._numerator.preprocess_record(params, record)
|
||||||
|
|
|
@ -144,12 +144,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||||
global_state.clipped_fraction_state)
|
global_state.clipped_fraction_state)
|
||||||
return self._SampleParams(sum_params, clipped_fraction_params)
|
return self._SampleParams(sum_params, clipped_fraction_params)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
sum_state = self._sum_query.initial_sample_state(
|
sum_state = self._sum_query.initial_sample_state(template)
|
||||||
global_state.sum_state, template)
|
|
||||||
clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
|
clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
|
||||||
global_state.clipped_fraction_state, tf.constant(0.0))
|
tf.constant(0.0))
|
||||||
return self._SampleState(sum_state, clipped_fraction_state)
|
return self._SampleState(sum_state, clipped_fraction_state)
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
|
|
|
@ -38,7 +38,7 @@ def run_query(query, records, global_state=None, weights=None):
|
||||||
if not global_state:
|
if not global_state:
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
sample_state = query.initial_sample_state(global_state, next(iter(records)))
|
sample_state = query.initial_sample_state(next(iter(records)))
|
||||||
if weights is None:
|
if weights is None:
|
||||||
for record in records:
|
for record in records:
|
||||||
sample_state = query.accumulate_record(params, sample_state, record)
|
sample_state = query.accumulate_record(params, sample_state, record)
|
||||||
|
|
|
@ -95,8 +95,7 @@ def make_optimizer_class(cls):
|
||||||
self._num_microbatches = tf.shape(vector_loss)[0]
|
self._num_microbatches = tf.shape(vector_loss)[0]
|
||||||
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
|
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
|
||||||
self._dp_sum_query.set_batch_size(self._num_microbatches)
|
self._dp_sum_query.set_batch_size(self._num_microbatches)
|
||||||
sample_state = self._dp_sum_query.initial_sample_state(
|
sample_state = self._dp_sum_query.initial_sample_state(var_list)
|
||||||
self._global_state, var_list)
|
|
||||||
microbatches_losses = tf.reshape(vector_loss,
|
microbatches_losses = tf.reshape(vector_loss,
|
||||||
[self._num_microbatches, -1])
|
[self._num_microbatches, -1])
|
||||||
sample_params = (
|
sample_params = (
|
||||||
|
@ -162,8 +161,7 @@ def make_optimizer_class(cls):
|
||||||
tf.trainable_variables() + tf.get_collection(
|
tf.trainable_variables() + tf.get_collection(
|
||||||
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||||
|
|
||||||
sample_state = self._dp_sum_query.initial_sample_state(
|
sample_state = self._dp_sum_query.initial_sample_state(var_list)
|
||||||
self._global_state, var_list)
|
|
||||||
|
|
||||||
if self._unroll_microbatches:
|
if self._unroll_microbatches:
|
||||||
for idx in range(self._num_microbatches):
|
for idx in range(self._num_microbatches):
|
||||||
|
|
Loading…
Reference in a new issue