forked from 626_privacy/tensorflow_privacy
Extensions to DPQuery and subclasses.
1. Split DPQuery.accumulate_record function into preprocess_record and accumulate_preprocessed_record. 2. Add merge_sample_state function. 3. Add default implementations for some functions in DPQuery, and add base class SumAggregationDPQuery that implements some more. Only get_noised_result is still abstract. 4. Enforce that all states and parameters used as inputs and outputs to DPQuery functions are nested structures of tensors by replacing numbers with constants and Nones with empty tuples. PiperOrigin-RevId: 247975791
This commit is contained in:
parent
82852c0e71
commit
1d1a6e087a
9 changed files with 221 additions and 212 deletions
|
@ -234,13 +234,22 @@ class QueryWithLedger(dp_query.DPQuery):
|
|||
"""See base class."""
|
||||
return self._query.derive_sample_params(global_state)
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
def initial_sample_state(self, global_state, template):
|
||||
"""See base class."""
|
||||
return self._query.initial_sample_state(global_state, tensors)
|
||||
return self._query.initial_sample_state(global_state, template)
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
def preprocess_record(self, params, record):
|
||||
"""See base class."""
|
||||
return self._query.accumulate_record(params, sample_state, record)
|
||||
return self._query.preprocess_record(params, record)
|
||||
|
||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||
"""See base class."""
|
||||
return self._query.accumulate_preprocessed_record(
|
||||
sample_state, preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""See base class."""
|
||||
return self._query.merge_sample_states(sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Ensures sample is recorded to the ledger and returns noised result."""
|
||||
|
|
|
@ -5,6 +5,10 @@ licenses(["notice"]) # Apache 2.0
|
|||
py_library(
|
||||
name = "dp_query",
|
||||
srcs = ["dp_query.py"],
|
||||
deps = [
|
||||
"//third_party/py/distutils",
|
||||
"//third_party/py/tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
# Copyright 2019, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -47,6 +47,13 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import tensorflow as tf
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
class DPQuery(object):
|
||||
|
@ -54,12 +61,10 @@ class DPQuery(object):
|
|||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the DPQuery."""
|
||||
pass
|
||||
return ()
|
||||
|
||||
@abc.abstractmethod
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
||||
|
@ -69,25 +74,74 @@ class DPQuery(object):
|
|||
Returns:
|
||||
Parameters to use to process records in the next sample.
|
||||
"""
|
||||
pass
|
||||
del global_state # unused.
|
||||
return ()
|
||||
|
||||
@abc.abstractmethod
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
def initial_sample_state(self, global_state, template):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
tensors: A structure of tensors used as a template to create the initial
|
||||
sample state.
|
||||
template: A nested structure of tensors, TensorSpecs, or numpy arrays used
|
||||
as a template to create the initial sample state. It is assumed that the
|
||||
leaves of the structure are python scalars or some type that has
|
||||
properties `shape` and `dtype`.
|
||||
|
||||
Returns: An initial sample state.
|
||||
"""
|
||||
pass
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Preprocesses a single record.
|
||||
|
||||
This preprocessing is applied to one client's record, e.g. selecting vectors
|
||||
and clipping them to a fixed L2 norm. This method can be executed in a
|
||||
separate TF session, or even on a different machine, so it should not depend
|
||||
on any TF inputs other than those provided as input arguments. In
|
||||
particular, implementations should avoid accessing any TF tensors or
|
||||
variables that are stored in self.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample. In standard DP-SGD training,
|
||||
the clipping norm for the sample's microbatch gradients (i.e.,
|
||||
a maximum norm magnitude to which each gradient is clipped)
|
||||
record: The record to be processed. In standard DP-SGD training,
|
||||
the gradient computed for the examples in one microbatch, which
|
||||
may be the gradient for just one example (for size 1 microbatches).
|
||||
|
||||
Returns:
|
||||
A structure of tensors to be aggregated.
|
||||
"""
|
||||
del params # unused.
|
||||
return record
|
||||
|
||||
@abc.abstractmethod
|
||||
def accumulate_preprocessed_record(
|
||||
self, sample_state, preprocessed_record):
|
||||
"""Accumulates a single preprocessed record into the sample state.
|
||||
|
||||
This method is intended to only do simple aggregation, typically just a sum.
|
||||
In the future, we might remove this method and replace it with a way to
|
||||
declaratively specify the type of aggregation required.
|
||||
|
||||
Args:
|
||||
sample_state: The current sample state. In standard DP-SGD training,
|
||||
the accumulated sum of previous clipped microbatch gradients.
|
||||
preprocessed_record: The preprocessed record to accumulate.
|
||||
|
||||
Returns:
|
||||
The updated sample state.
|
||||
"""
|
||||
pass
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
|
||||
This is a helper method that simply delegates to `preprocess_record` and
|
||||
`accumulate_preprocessed_record` for the common case when both of those
|
||||
functions run on a single device.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample. In standard DP-SGD training,
|
||||
the clipping norm for the sample's microbatch gradients (i.e.,
|
||||
|
@ -102,6 +156,21 @@ class DPQuery(object):
|
|||
The updated sample state. In standard DP-SGD training, the set of
|
||||
previous mcrobatch gradients with the addition of the record argument.
|
||||
"""
|
||||
preprocessed_record = self.preprocess_record(params, record)
|
||||
return self.accumulate_preprocessed_record(
|
||||
sample_state, preprocessed_record)
|
||||
|
||||
@abc.abstractmethod
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""Merges two sample states into a single state.
|
||||
|
||||
Args:
|
||||
sample_state_1: The first sample state to merge.
|
||||
sample_state_2: The second sample state to merge.
|
||||
|
||||
Returns:
|
||||
The merged sample state.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -123,3 +192,26 @@ class DPQuery(object):
|
|||
averaging performed in a manner that guarantees differential privacy.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def zeros_like(arg):
|
||||
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
|
||||
try:
|
||||
arg = tf.convert_to_tensor(arg)
|
||||
except TypeError:
|
||||
pass
|
||||
return tf.zeros(arg.shape, arg.dtype)
|
||||
|
||||
|
||||
class SumAggregationDPQuery(DPQuery):
|
||||
"""Base class for DPQueries that aggregate via sum."""
|
||||
|
||||
def initial_sample_state(self, global_state, template):
|
||||
del global_state # unused.
|
||||
return nest.map_structure(zeros_like, template)
|
||||
|
||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||
return nest.map_structure(tf.add, sample_state, preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
return nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||
|
|
|
@ -31,7 +31,7 @@ else:
|
|||
nest = tf.nest
|
||||
|
||||
|
||||
class GaussianSumQuery(dp_query.DPQuery):
|
||||
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery interface for Gaussian sum queries.
|
||||
|
||||
Accumulates clipped vectors, then adds Gaussian noise to the sum.
|
||||
|
@ -50,31 +50,10 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
self._stddev = tf.cast(stddev, tf.float32)
|
||||
self._ledger = ledger
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the GaussianSumQuery."""
|
||||
return None
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
|
||||
Returns:
|
||||
Parameters to use to process records in the next sample.
|
||||
"""
|
||||
return self._l2_norm_clip
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
tensors: A structure of tensors used as a template to create the initial
|
||||
sample state.
|
||||
|
||||
Returns: An initial sample state.
|
||||
"""
|
||||
def initial_sample_state(self, global_state, template):
|
||||
if self._ledger:
|
||||
dependencies = [
|
||||
self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
|
||||
|
@ -82,51 +61,32 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
else:
|
||||
dependencies = []
|
||||
with tf.control_dependencies(dependencies):
|
||||
return nest.map_structure(tf.zeros_like, tensors)
|
||||
return nest.map_structure(
|
||||
dp_query.zeros_like, template)
|
||||
|
||||
def accumulate_record_impl(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
def preprocess_record_impl(self, params, record):
|
||||
"""Clips the l2 norm, returning the clipped record and the l2 norm.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
sample_state: The current sample state.
|
||||
record: The record to accumulate.
|
||||
record: The record to be processed.
|
||||
|
||||
Returns:
|
||||
A tuple containing the updated sample state and the global norm.
|
||||
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
|
||||
the structure of preprocessed tensors, and l2_norm is the total l2 norm
|
||||
before clipping.
|
||||
"""
|
||||
l2_norm_clip = params
|
||||
record_as_list = nest.flatten(record)
|
||||
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
||||
clipped = nest.pack_sequence_as(record, clipped_as_list)
|
||||
return nest.map_structure(tf.add, sample_state, clipped), norm
|
||||
return nest.pack_sequence_as(record, clipped_as_list), norm
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
sample_state: The current sample state.
|
||||
record: The record to accumulate.
|
||||
|
||||
Returns:
|
||||
The updated sample state.
|
||||
"""
|
||||
new_sample_state, _ = self.accumulate_record_impl(
|
||||
params, sample_state, record)
|
||||
return new_sample_state
|
||||
def preprocess_record(self, params, record):
|
||||
preprocessed_record, _ = self.preprocess_record_impl(params, record)
|
||||
return preprocessed_record
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Gets noised sum after all records of sample have been accumulated.
|
||||
|
||||
Args:
|
||||
sample_state: The sample state after all records have been accumulated.
|
||||
global_state: The global state.
|
||||
|
||||
Returns:
|
||||
A tuple (estimate, new_global_state) where "estimate" is the estimated
|
||||
sum of the records and "new_global_state" is the updated global state.
|
||||
"""
|
||||
"""See base class."""
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
def add_noise(v):
|
||||
return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
|
||||
|
|
|
@ -91,6 +91,32 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
result_stddev = np.std(noised_sums)
|
||||
self.assertNear(result_stddev, stddev, 0.1)
|
||||
|
||||
def test_gaussian_sum_merge(self):
|
||||
records1 = [tf.constant([2.0, 0.0]), tf.constant([-1.0, 1.0])]
|
||||
records2 = [tf.constant([3.0, 5.0]), tf.constant([-1.0, 4.0])]
|
||||
|
||||
def get_sample_state(records):
|
||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
sample_state = query.initial_sample_state(global_state, records[0])
|
||||
for record in records:
|
||||
sample_state = query.accumulate_record(params, sample_state, record)
|
||||
return sample_state
|
||||
|
||||
sample_state_1 = get_sample_state(records1)
|
||||
sample_state_2 = get_sample_state(records2)
|
||||
|
||||
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
|
||||
sample_state_1,
|
||||
sample_state_2)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = sess.run(merged)
|
||||
|
||||
expected = [3.0, 10.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_gaussian_average_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
|
||||
|
|
|
@ -56,52 +56,39 @@ class NestedQuery(dp_query.DPQuery):
|
|||
"""
|
||||
self._queries = queries
|
||||
|
||||
def _map_to_queries(self, fn, *inputs):
|
||||
def _map_to_queries(self, fn, *inputs, **kwargs):
|
||||
def caller(query, *args):
|
||||
return getattr(query, fn)(*args)
|
||||
return getattr(query, fn)(*args, **kwargs)
|
||||
return nest.map_structure_up_to(
|
||||
self._queries, caller, self._queries, *inputs)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the NestedQuery."""
|
||||
"""See base class."""
|
||||
return self._map_to_queries('initial_global_state')
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
|
||||
Returns:
|
||||
Parameters to use to process records in the next sample.
|
||||
"""
|
||||
"""See base class."""
|
||||
return self._map_to_queries('derive_sample_params', global_state)
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
def initial_sample_state(self, global_state, template):
|
||||
"""See base class."""
|
||||
return self._map_to_queries('initial_sample_state', global_state, template)
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
tensors: A structure of tensors used as a template to create the initial
|
||||
sample state.
|
||||
def preprocess_record(self, params, record):
|
||||
"""See base class."""
|
||||
return self._map_to_queries('preprocess_record', params, record)
|
||||
|
||||
Returns: An initial sample state.
|
||||
"""
|
||||
return self._map_to_queries('initial_sample_state', global_state, tensors)
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
sample_state: The current sample state.
|
||||
record: The record to accumulate.
|
||||
|
||||
Returns:
|
||||
The updated sample state.
|
||||
"""
|
||||
def accumulate_preprocessed_record(
|
||||
self, sample_state, preprocessed_record):
|
||||
"""See base class."""
|
||||
return self._map_to_queries(
|
||||
'accumulate_record', params, sample_state, record)
|
||||
'accumulate_preprocessed_record',
|
||||
sample_state,
|
||||
preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
return self._map_to_queries(
|
||||
'merge_sample_states', sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Gets query result after all records of sample have been accumulated.
|
||||
|
|
|
@ -28,78 +28,44 @@ else:
|
|||
nest = tf.nest
|
||||
|
||||
|
||||
class NoPrivacySumQuery(dp_query.DPQuery):
|
||||
class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery interface for a sum query with no privacy.
|
||||
|
||||
Accumulates vectors without clipping or adding noise.
|
||||
"""
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the NoPrivacySumQuery."""
|
||||
return None
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""See base class."""
|
||||
del global_state # unused.
|
||||
return None
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""See base class."""
|
||||
del global_state # unused.
|
||||
return nest.map_structure(tf.zeros_like, tensors)
|
||||
|
||||
def accumulate_record(self, params, sample_state, record, weight=1):
|
||||
"""See base class. Optional argument for weighted sum queries."""
|
||||
del params # unused.
|
||||
|
||||
def add_weighted(state_tensor, record_tensor):
|
||||
return tf.add(state_tensor, weight * record_tensor)
|
||||
|
||||
return nest.map_structure(add_weighted, sample_state, record)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
return sample_state, global_state
|
||||
|
||||
|
||||
class NoPrivacyAverageQuery(dp_query.DPQuery):
|
||||
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery interface for an average query with no privacy.
|
||||
|
||||
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the NoPrivacyAverageQuery."""
|
||||
self._numerator = NoPrivacySumQuery()
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the NoPrivacyAverageQuery."""
|
||||
return self._numerator.initial_global_state()
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
def initial_sample_state(self, global_state, template):
|
||||
"""See base class."""
|
||||
del global_state # unused.
|
||||
return None
|
||||
return (
|
||||
super(NoPrivacyAverageQuery, self).initial_sample_state(
|
||||
global_state, template),
|
||||
tf.constant(0.0))
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""See base class."""
|
||||
return self._numerator.initial_sample_state(global_state, tensors), 0.0
|
||||
def preprocess_record(self, params, record, weight=1):
|
||||
"""Multiplies record by weight."""
|
||||
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||
return (weighted_record, weight)
|
||||
|
||||
def accumulate_record(self, params, sample_state, record, weight=1):
|
||||
"""See base class. Optional argument for weighted average queries."""
|
||||
sum_sample_state, denominator = sample_state
|
||||
return (
|
||||
self._numerator.accumulate_record(
|
||||
params, sum_sample_state, record, weight),
|
||||
tf.add(denominator, weight))
|
||||
"""Accumulates record, multiplying by weight."""
|
||||
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||
return self.accumulate_preprocessed_record(
|
||||
sample_state, (weighted_record, weight))
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
sum_sample_state, denominator = sample_state
|
||||
exact_sum, new_global_state = self._numerator.get_noised_result(
|
||||
sum_sample_state, global_state)
|
||||
sum_state, denominator = sample_state
|
||||
|
||||
def normalize(v):
|
||||
return tf.truediv(v, denominator)
|
||||
|
||||
return nest.map_structure(normalize, exact_sum), new_global_state
|
||||
return nest.map_structure(
|
||||
lambda t: tf.truediv(t, denominator), sum_state), ()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
# Copyright 2019, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -27,7 +27,7 @@ from privacy.dp_query import test_utils
|
|||
|
||||
class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_no_privacy_sum(self):
|
||||
def test_sum(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([2.0, 0.0])
|
||||
record2 = tf.constant([-1.0, 1.0])
|
||||
|
@ -38,20 +38,6 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_no_privacy_weighted_sum(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([2.0, 0.0])
|
||||
record2 = tf.constant([-1.0, 1.0])
|
||||
|
||||
weights = [1, 2]
|
||||
|
||||
query = no_privacy_query.NoPrivacySumQuery()
|
||||
query_result, _ = test_utils.run_query(
|
||||
query, [record1, record2], weights=weights)
|
||||
result = sess.run(query_result)
|
||||
expected = [0.0, 2.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_no_privacy_average(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([5.0, 0.0])
|
||||
|
|
|
@ -38,65 +38,39 @@ class NormalizedQuery(dp_query.DPQuery):
|
|||
|
||||
Args:
|
||||
numerator_query: A DPQuery for the numerator.
|
||||
denominator: A value for the denominator.
|
||||
denominator: A value for the denominator. May be None if it will be
|
||||
supplied via the set_denominator function before get_noised_result is
|
||||
called.
|
||||
"""
|
||||
self._numerator = numerator_query
|
||||
self._denominator = tf.cast(denominator,
|
||||
tf.float32) if denominator is not None else None
|
||||
self._denominator = (
|
||||
tf.cast(denominator, tf.float32) if denominator is not None else None)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the NormalizedQuery."""
|
||||
"""See base class."""
|
||||
# NormalizedQuery has no global state beyond the numerator state.
|
||||
return self._numerator.initial_global_state()
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
|
||||
Returns:
|
||||
Parameters to use to process records in the next sample.
|
||||
"""
|
||||
"""See base class."""
|
||||
return self._numerator.derive_sample_params(global_state)
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
tensors: A structure of tensors used as a template to create the initial
|
||||
sample state.
|
||||
|
||||
Returns: An initial sample state.
|
||||
"""
|
||||
def initial_sample_state(self, global_state, template):
|
||||
"""See base class."""
|
||||
# NormalizedQuery has no sample state beyond the numerator state.
|
||||
return self._numerator.initial_sample_state(global_state, tensors)
|
||||
return self._numerator.initial_sample_state(global_state, template)
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
def preprocess_record(self, params, record):
|
||||
return self._numerator.preprocess_record(params, record)
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
sample_state: The current sample state.
|
||||
record: The record to accumulate.
|
||||
|
||||
Returns:
|
||||
The updated sample state.
|
||||
"""
|
||||
return self._numerator.accumulate_record(params, sample_state, record)
|
||||
def accumulate_preprocessed_record(
|
||||
self, sample_state, preprocessed_record):
|
||||
"""See base class."""
|
||||
return self._numerator.accumulate_preprocessed_record(
|
||||
sample_state, preprocessed_record)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Gets noised average after all records of sample have been accumulated.
|
||||
|
||||
Args:
|
||||
sample_state: The sample state after all records have been accumulated.
|
||||
global_state: The global state.
|
||||
|
||||
Returns:
|
||||
A tuple (estimate, new_global_state) where "estimate" is the estimated
|
||||
average of the records and "new_global_state" is the updated global state.
|
||||
"""
|
||||
"""See base class."""
|
||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||
sample_state, global_state)
|
||||
def normalize(v):
|
||||
|
@ -104,5 +78,10 @@ class NormalizedQuery(dp_query.DPQuery):
|
|||
|
||||
return nest.map_structure(normalize, noised_sum), new_sum_global_state
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""See base class."""
|
||||
return self._numerator.merge_sample_states(sample_state_1, sample_state_2)
|
||||
|
||||
def set_denominator(self, denominator):
|
||||
"""Sets the denominator for the NormalizedQuery."""
|
||||
self._denominator = tf.cast(denominator, tf.float32)
|
||||
|
|
Loading…
Reference in a new issue