forked from 626_privacy/tensorflow_privacy
Extensions to DPQuery and subclasses.
1. Split DPQuery.accumulate_record function into preprocess_record and accumulate_preprocessed_record. 2. Add merge_sample_state function. 3. Add default implementations for some functions in DPQuery, and add base class SumAggregationDPQuery that implements some more. Only get_noised_result is still abstract. 4. Enforce that all states and parameters used as inputs and outputs to DPQuery functions are nested structures of tensors by replacing numbers with constants and Nones with empty tuples. PiperOrigin-RevId: 247975791
This commit is contained in:
parent
82852c0e71
commit
1d1a6e087a
9 changed files with 221 additions and 212 deletions
|
@ -234,13 +234,22 @@ class QueryWithLedger(dp_query.DPQuery):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._query.derive_sample_params(global_state)
|
return self._query.derive_sample_params(global_state)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def initial_sample_state(self, global_state, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._query.initial_sample_state(global_state, tensors)
|
return self._query.initial_sample_state(global_state, template)
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record):
|
def preprocess_record(self, params, record):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._query.accumulate_record(params, sample_state, record)
|
return self._query.preprocess_record(params, record)
|
||||||
|
|
||||||
|
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||||
|
"""See base class."""
|
||||||
|
return self._query.accumulate_preprocessed_record(
|
||||||
|
sample_state, preprocessed_record)
|
||||||
|
|
||||||
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
|
"""See base class."""
|
||||||
|
return self._query.merge_sample_states(sample_state_1, sample_state_2)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Ensures sample is recorded to the ledger and returns noised result."""
|
"""Ensures sample is recorded to the ledger and returns noised result."""
|
||||||
|
|
|
@ -5,6 +5,10 @@ licenses(["notice"]) # Apache 2.0
|
||||||
py_library(
|
py_library(
|
||||||
name = "dp_query",
|
name = "dp_query",
|
||||||
srcs = ["dp_query.py"],
|
srcs = ["dp_query.py"],
|
||||||
|
deps = [
|
||||||
|
"//third_party/py/distutils",
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2018, The TensorFlow Authors.
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -47,6 +47,13 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
|
nest = tf.contrib.framework.nest
|
||||||
|
else:
|
||||||
|
nest = tf.nest
|
||||||
|
|
||||||
|
|
||||||
class DPQuery(object):
|
class DPQuery(object):
|
||||||
|
@ -54,12 +61,10 @@ class DPQuery(object):
|
||||||
|
|
||||||
__metaclass__ = abc.ABCMeta
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Returns the initial global state for the DPQuery."""
|
"""Returns the initial global state for the DPQuery."""
|
||||||
pass
|
return ()
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Given the global state, derives parameters to use for the next sample.
|
"""Given the global state, derives parameters to use for the next sample.
|
||||||
|
|
||||||
|
@ -69,25 +74,74 @@ class DPQuery(object):
|
||||||
Returns:
|
Returns:
|
||||||
Parameters to use to process records in the next sample.
|
Parameters to use to process records in the next sample.
|
||||||
"""
|
"""
|
||||||
pass
|
del global_state # unused.
|
||||||
|
return ()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def initial_sample_state(self, global_state, template):
|
||||||
"""Returns an initial state to use for the next sample.
|
"""Returns an initial state to use for the next sample.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
global_state: The current global state.
|
global_state: The current global state.
|
||||||
tensors: A structure of tensors used as a template to create the initial
|
template: A nested structure of tensors, TensorSpecs, or numpy arrays used
|
||||||
sample state.
|
as a template to create the initial sample state. It is assumed that the
|
||||||
|
leaves of the structure are python scalars or some type that has
|
||||||
|
properties `shape` and `dtype`.
|
||||||
|
|
||||||
Returns: An initial sample state.
|
Returns: An initial sample state.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def preprocess_record(self, params, record):
|
||||||
|
"""Preprocesses a single record.
|
||||||
|
|
||||||
|
This preprocessing is applied to one client's record, e.g. selecting vectors
|
||||||
|
and clipping them to a fixed L2 norm. This method can be executed in a
|
||||||
|
separate TF session, or even on a different machine, so it should not depend
|
||||||
|
on any TF inputs other than those provided as input arguments. In
|
||||||
|
particular, implementations should avoid accessing any TF tensors or
|
||||||
|
variables that are stored in self.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: The parameters for the sample. In standard DP-SGD training,
|
||||||
|
the clipping norm for the sample's microbatch gradients (i.e.,
|
||||||
|
a maximum norm magnitude to which each gradient is clipped)
|
||||||
|
record: The record to be processed. In standard DP-SGD training,
|
||||||
|
the gradient computed for the examples in one microbatch, which
|
||||||
|
may be the gradient for just one example (for size 1 microbatches).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A structure of tensors to be aggregated.
|
||||||
|
"""
|
||||||
|
del params # unused.
|
||||||
|
return record
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
def accumulate_preprocessed_record(
|
||||||
|
self, sample_state, preprocessed_record):
|
||||||
|
"""Accumulates a single preprocessed record into the sample state.
|
||||||
|
|
||||||
|
This method is intended to only do simple aggregation, typically just a sum.
|
||||||
|
In the future, we might remove this method and replace it with a way to
|
||||||
|
declaratively specify the type of aggregation required.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_state: The current sample state. In standard DP-SGD training,
|
||||||
|
the accumulated sum of previous clipped microbatch gradients.
|
||||||
|
preprocessed_record: The preprocessed record to accumulate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated sample state.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record):
|
def accumulate_record(self, params, sample_state, record):
|
||||||
"""Accumulates a single record into the sample state.
|
"""Accumulates a single record into the sample state.
|
||||||
|
|
||||||
|
This is a helper method that simply delegates to `preprocess_record` and
|
||||||
|
`accumulate_preprocessed_record` for the common case when both of those
|
||||||
|
functions run on a single device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: The parameters for the sample. In standard DP-SGD training,
|
params: The parameters for the sample. In standard DP-SGD training,
|
||||||
the clipping norm for the sample's microbatch gradients (i.e.,
|
the clipping norm for the sample's microbatch gradients (i.e.,
|
||||||
|
@ -102,6 +156,21 @@ class DPQuery(object):
|
||||||
The updated sample state. In standard DP-SGD training, the set of
|
The updated sample state. In standard DP-SGD training, the set of
|
||||||
previous mcrobatch gradients with the addition of the record argument.
|
previous mcrobatch gradients with the addition of the record argument.
|
||||||
"""
|
"""
|
||||||
|
preprocessed_record = self.preprocess_record(params, record)
|
||||||
|
return self.accumulate_preprocessed_record(
|
||||||
|
sample_state, preprocessed_record)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
|
"""Merges two sample states into a single state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_state_1: The first sample state to merge.
|
||||||
|
sample_state_2: The second sample state to merge.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The merged sample state.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -123,3 +192,26 @@ class DPQuery(object):
|
||||||
averaging performed in a manner that guarantees differential privacy.
|
averaging performed in a manner that guarantees differential privacy.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def zeros_like(arg):
|
||||||
|
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
|
||||||
|
try:
|
||||||
|
arg = tf.convert_to_tensor(arg)
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
return tf.zeros(arg.shape, arg.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class SumAggregationDPQuery(DPQuery):
|
||||||
|
"""Base class for DPQueries that aggregate via sum."""
|
||||||
|
|
||||||
|
def initial_sample_state(self, global_state, template):
|
||||||
|
del global_state # unused.
|
||||||
|
return nest.map_structure(zeros_like, template)
|
||||||
|
|
||||||
|
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||||
|
return nest.map_structure(tf.add, sample_state, preprocessed_record)
|
||||||
|
|
||||||
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
|
return nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||||
|
|
|
@ -31,7 +31,7 @@ else:
|
||||||
nest = tf.nest
|
nest = tf.nest
|
||||||
|
|
||||||
|
|
||||||
class GaussianSumQuery(dp_query.DPQuery):
|
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Implements DPQuery interface for Gaussian sum queries.
|
"""Implements DPQuery interface for Gaussian sum queries.
|
||||||
|
|
||||||
Accumulates clipped vectors, then adds Gaussian noise to the sum.
|
Accumulates clipped vectors, then adds Gaussian noise to the sum.
|
||||||
|
@ -50,31 +50,10 @@ class GaussianSumQuery(dp_query.DPQuery):
|
||||||
self._stddev = tf.cast(stddev, tf.float32)
|
self._stddev = tf.cast(stddev, tf.float32)
|
||||||
self._ledger = ledger
|
self._ledger = ledger
|
||||||
|
|
||||||
def initial_global_state(self):
|
|
||||||
"""Returns the initial global state for the GaussianSumQuery."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Given the global state, derives parameters to use for the next sample.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
global_state: The current global state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parameters to use to process records in the next sample.
|
|
||||||
"""
|
|
||||||
return self._l2_norm_clip
|
return self._l2_norm_clip
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def initial_sample_state(self, global_state, template):
|
||||||
"""Returns an initial state to use for the next sample.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
global_state: The current global state.
|
|
||||||
tensors: A structure of tensors used as a template to create the initial
|
|
||||||
sample state.
|
|
||||||
|
|
||||||
Returns: An initial sample state.
|
|
||||||
"""
|
|
||||||
if self._ledger:
|
if self._ledger:
|
||||||
dependencies = [
|
dependencies = [
|
||||||
self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
|
self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
|
||||||
|
@ -82,51 +61,32 @@ class GaussianSumQuery(dp_query.DPQuery):
|
||||||
else:
|
else:
|
||||||
dependencies = []
|
dependencies = []
|
||||||
with tf.control_dependencies(dependencies):
|
with tf.control_dependencies(dependencies):
|
||||||
return nest.map_structure(tf.zeros_like, tensors)
|
return nest.map_structure(
|
||||||
|
dp_query.zeros_like, template)
|
||||||
|
|
||||||
def accumulate_record_impl(self, params, sample_state, record):
|
def preprocess_record_impl(self, params, record):
|
||||||
"""Accumulates a single record into the sample state.
|
"""Clips the l2 norm, returning the clipped record and the l2 norm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: The parameters for the sample.
|
params: The parameters for the sample.
|
||||||
sample_state: The current sample state.
|
record: The record to be processed.
|
||||||
record: The record to accumulate.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple containing the updated sample state and the global norm.
|
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
|
||||||
|
the structure of preprocessed tensors, and l2_norm is the total l2 norm
|
||||||
|
before clipping.
|
||||||
"""
|
"""
|
||||||
l2_norm_clip = params
|
l2_norm_clip = params
|
||||||
record_as_list = nest.flatten(record)
|
record_as_list = nest.flatten(record)
|
||||||
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
||||||
clipped = nest.pack_sequence_as(record, clipped_as_list)
|
return nest.pack_sequence_as(record, clipped_as_list), norm
|
||||||
return nest.map_structure(tf.add, sample_state, clipped), norm
|
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record):
|
def preprocess_record(self, params, record):
|
||||||
"""Accumulates a single record into the sample state.
|
preprocessed_record, _ = self.preprocess_record_impl(params, record)
|
||||||
|
return preprocessed_record
|
||||||
Args:
|
|
||||||
params: The parameters for the sample.
|
|
||||||
sample_state: The current sample state.
|
|
||||||
record: The record to accumulate.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The updated sample state.
|
|
||||||
"""
|
|
||||||
new_sample_state, _ = self.accumulate_record_impl(
|
|
||||||
params, sample_state, record)
|
|
||||||
return new_sample_state
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Gets noised sum after all records of sample have been accumulated.
|
"""See base class."""
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_state: The sample state after all records have been accumulated.
|
|
||||||
global_state: The global state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (estimate, new_global_state) where "estimate" is the estimated
|
|
||||||
sum of the records and "new_global_state" is the updated global state.
|
|
||||||
"""
|
|
||||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
|
return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
|
||||||
|
|
|
@ -91,6 +91,32 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
result_stddev = np.std(noised_sums)
|
result_stddev = np.std(noised_sums)
|
||||||
self.assertNear(result_stddev, stddev, 0.1)
|
self.assertNear(result_stddev, stddev, 0.1)
|
||||||
|
|
||||||
|
def test_gaussian_sum_merge(self):
|
||||||
|
records1 = [tf.constant([2.0, 0.0]), tf.constant([-1.0, 1.0])]
|
||||||
|
records2 = [tf.constant([3.0, 5.0]), tf.constant([-1.0, 4.0])]
|
||||||
|
|
||||||
|
def get_sample_state(records):
|
||||||
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
params = query.derive_sample_params(global_state)
|
||||||
|
sample_state = query.initial_sample_state(global_state, records[0])
|
||||||
|
for record in records:
|
||||||
|
sample_state = query.accumulate_record(params, sample_state, record)
|
||||||
|
return sample_state
|
||||||
|
|
||||||
|
sample_state_1 = get_sample_state(records1)
|
||||||
|
sample_state_2 = get_sample_state(records2)
|
||||||
|
|
||||||
|
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
|
||||||
|
sample_state_1,
|
||||||
|
sample_state_2)
|
||||||
|
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
result = sess.run(merged)
|
||||||
|
|
||||||
|
expected = [3.0, 10.0]
|
||||||
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
def test_gaussian_average_no_noise(self):
|
def test_gaussian_average_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
|
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
|
||||||
|
|
|
@ -56,52 +56,39 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
"""
|
"""
|
||||||
self._queries = queries
|
self._queries = queries
|
||||||
|
|
||||||
def _map_to_queries(self, fn, *inputs):
|
def _map_to_queries(self, fn, *inputs, **kwargs):
|
||||||
def caller(query, *args):
|
def caller(query, *args):
|
||||||
return getattr(query, fn)(*args)
|
return getattr(query, fn)(*args, **kwargs)
|
||||||
return nest.map_structure_up_to(
|
return nest.map_structure_up_to(
|
||||||
self._queries, caller, self._queries, *inputs)
|
self._queries, caller, self._queries, *inputs)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Returns the initial global state for the NestedQuery."""
|
"""See base class."""
|
||||||
return self._map_to_queries('initial_global_state')
|
return self._map_to_queries('initial_global_state')
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Given the global state, derives parameters to use for the next sample.
|
"""See base class."""
|
||||||
|
|
||||||
Args:
|
|
||||||
global_state: The current global state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parameters to use to process records in the next sample.
|
|
||||||
"""
|
|
||||||
return self._map_to_queries('derive_sample_params', global_state)
|
return self._map_to_queries('derive_sample_params', global_state)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def initial_sample_state(self, global_state, template):
|
||||||
"""Returns an initial state to use for the next sample.
|
"""See base class."""
|
||||||
|
return self._map_to_queries('initial_sample_state', global_state, template)
|
||||||
|
|
||||||
Args:
|
def preprocess_record(self, params, record):
|
||||||
global_state: The current global state.
|
"""See base class."""
|
||||||
tensors: A structure of tensors used as a template to create the initial
|
return self._map_to_queries('preprocess_record', params, record)
|
||||||
sample state.
|
|
||||||
|
|
||||||
Returns: An initial sample state.
|
def accumulate_preprocessed_record(
|
||||||
"""
|
self, sample_state, preprocessed_record):
|
||||||
return self._map_to_queries('initial_sample_state', global_state, tensors)
|
"""See base class."""
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record):
|
|
||||||
"""Accumulates a single record into the sample state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params: The parameters for the sample.
|
|
||||||
sample_state: The current sample state.
|
|
||||||
record: The record to accumulate.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The updated sample state.
|
|
||||||
"""
|
|
||||||
return self._map_to_queries(
|
return self._map_to_queries(
|
||||||
'accumulate_record', params, sample_state, record)
|
'accumulate_preprocessed_record',
|
||||||
|
sample_state,
|
||||||
|
preprocessed_record)
|
||||||
|
|
||||||
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
|
return self._map_to_queries(
|
||||||
|
'merge_sample_states', sample_state_1, sample_state_2)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Gets query result after all records of sample have been accumulated.
|
"""Gets query result after all records of sample have been accumulated.
|
||||||
|
|
|
@ -28,78 +28,44 @@ else:
|
||||||
nest = tf.nest
|
nest = tf.nest
|
||||||
|
|
||||||
|
|
||||||
class NoPrivacySumQuery(dp_query.DPQuery):
|
class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Implements DPQuery interface for a sum query with no privacy.
|
"""Implements DPQuery interface for a sum query with no privacy.
|
||||||
|
|
||||||
Accumulates vectors without clipping or adding noise.
|
Accumulates vectors without clipping or adding noise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def initial_global_state(self):
|
|
||||||
"""Returns the initial global state for the NoPrivacySumQuery."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
|
||||||
"""See base class."""
|
|
||||||
del global_state # unused.
|
|
||||||
return None
|
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
|
||||||
"""See base class."""
|
|
||||||
del global_state # unused.
|
|
||||||
return nest.map_structure(tf.zeros_like, tensors)
|
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record, weight=1):
|
|
||||||
"""See base class. Optional argument for weighted sum queries."""
|
|
||||||
del params # unused.
|
|
||||||
|
|
||||||
def add_weighted(state_tensor, record_tensor):
|
|
||||||
return tf.add(state_tensor, weight * record_tensor)
|
|
||||||
|
|
||||||
return nest.map_structure(add_weighted, sample_state, record)
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return sample_state, global_state
|
return sample_state, global_state
|
||||||
|
|
||||||
|
|
||||||
class NoPrivacyAverageQuery(dp_query.DPQuery):
|
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Implements DPQuery interface for an average query with no privacy.
|
"""Implements DPQuery interface for an average query with no privacy.
|
||||||
|
|
||||||
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def initial_sample_state(self, global_state, template):
|
||||||
"""Initializes the NoPrivacyAverageQuery."""
|
|
||||||
self._numerator = NoPrivacySumQuery()
|
|
||||||
|
|
||||||
def initial_global_state(self):
|
|
||||||
"""Returns the initial global state for the NoPrivacyAverageQuery."""
|
|
||||||
return self._numerator.initial_global_state()
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
del global_state # unused.
|
return (
|
||||||
return None
|
super(NoPrivacyAverageQuery, self).initial_sample_state(
|
||||||
|
global_state, template),
|
||||||
|
tf.constant(0.0))
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def preprocess_record(self, params, record, weight=1):
|
||||||
"""See base class."""
|
"""Multiplies record by weight."""
|
||||||
return self._numerator.initial_sample_state(global_state, tensors), 0.0
|
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||||
|
return (weighted_record, weight)
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record, weight=1):
|
def accumulate_record(self, params, sample_state, record, weight=1):
|
||||||
"""See base class. Optional argument for weighted average queries."""
|
"""Accumulates record, multiplying by weight."""
|
||||||
sum_sample_state, denominator = sample_state
|
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||||
return (
|
return self.accumulate_preprocessed_record(
|
||||||
self._numerator.accumulate_record(
|
sample_state, (weighted_record, weight))
|
||||||
params, sum_sample_state, record, weight),
|
|
||||||
tf.add(denominator, weight))
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
sum_sample_state, denominator = sample_state
|
sum_state, denominator = sample_state
|
||||||
exact_sum, new_global_state = self._numerator.get_noised_result(
|
|
||||||
sum_sample_state, global_state)
|
|
||||||
|
|
||||||
def normalize(v):
|
return nest.map_structure(
|
||||||
return tf.truediv(v, denominator)
|
lambda t: tf.truediv(t, denominator), sum_state), ()
|
||||||
|
|
||||||
return nest.map_structure(normalize, exact_sum), new_global_state
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2018, The TensorFlow Authors.
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -27,7 +27,7 @@ from privacy.dp_query import test_utils
|
||||||
|
|
||||||
class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_no_privacy_sum(self):
|
def test_sum(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
record1 = tf.constant([2.0, 0.0])
|
record1 = tf.constant([2.0, 0.0])
|
||||||
record2 = tf.constant([-1.0, 1.0])
|
record2 = tf.constant([-1.0, 1.0])
|
||||||
|
@ -38,20 +38,6 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
def test_no_privacy_weighted_sum(self):
|
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([2.0, 0.0])
|
|
||||||
record2 = tf.constant([-1.0, 1.0])
|
|
||||||
|
|
||||||
weights = [1, 2]
|
|
||||||
|
|
||||||
query = no_privacy_query.NoPrivacySumQuery()
|
|
||||||
query_result, _ = test_utils.run_query(
|
|
||||||
query, [record1, record2], weights=weights)
|
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [0.0, 2.0]
|
|
||||||
self.assertAllClose(result, expected)
|
|
||||||
|
|
||||||
def test_no_privacy_average(self):
|
def test_no_privacy_average(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
record1 = tf.constant([5.0, 0.0])
|
record1 = tf.constant([5.0, 0.0])
|
||||||
|
|
|
@ -38,65 +38,39 @@ class NormalizedQuery(dp_query.DPQuery):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
numerator_query: A DPQuery for the numerator.
|
numerator_query: A DPQuery for the numerator.
|
||||||
denominator: A value for the denominator.
|
denominator: A value for the denominator. May be None if it will be
|
||||||
|
supplied via the set_denominator function before get_noised_result is
|
||||||
|
called.
|
||||||
"""
|
"""
|
||||||
self._numerator = numerator_query
|
self._numerator = numerator_query
|
||||||
self._denominator = tf.cast(denominator,
|
self._denominator = (
|
||||||
tf.float32) if denominator is not None else None
|
tf.cast(denominator, tf.float32) if denominator is not None else None)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Returns the initial global state for the NormalizedQuery."""
|
"""See base class."""
|
||||||
# NormalizedQuery has no global state beyond the numerator state.
|
# NormalizedQuery has no global state beyond the numerator state.
|
||||||
return self._numerator.initial_global_state()
|
return self._numerator.initial_global_state()
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Given the global state, derives parameters to use for the next sample.
|
"""See base class."""
|
||||||
|
|
||||||
Args:
|
|
||||||
global_state: The current global state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parameters to use to process records in the next sample.
|
|
||||||
"""
|
|
||||||
return self._numerator.derive_sample_params(global_state)
|
return self._numerator.derive_sample_params(global_state)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def initial_sample_state(self, global_state, template):
|
||||||
"""Returns an initial state to use for the next sample.
|
"""See base class."""
|
||||||
|
|
||||||
Args:
|
|
||||||
global_state: The current global state.
|
|
||||||
tensors: A structure of tensors used as a template to create the initial
|
|
||||||
sample state.
|
|
||||||
|
|
||||||
Returns: An initial sample state.
|
|
||||||
"""
|
|
||||||
# NormalizedQuery has no sample state beyond the numerator state.
|
# NormalizedQuery has no sample state beyond the numerator state.
|
||||||
return self._numerator.initial_sample_state(global_state, tensors)
|
return self._numerator.initial_sample_state(global_state, template)
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record):
|
def preprocess_record(self, params, record):
|
||||||
"""Accumulates a single record into the sample state.
|
return self._numerator.preprocess_record(params, record)
|
||||||
|
|
||||||
Args:
|
def accumulate_preprocessed_record(
|
||||||
params: The parameters for the sample.
|
self, sample_state, preprocessed_record):
|
||||||
sample_state: The current sample state.
|
"""See base class."""
|
||||||
record: The record to accumulate.
|
return self._numerator.accumulate_preprocessed_record(
|
||||||
|
sample_state, preprocessed_record)
|
||||||
Returns:
|
|
||||||
The updated sample state.
|
|
||||||
"""
|
|
||||||
return self._numerator.accumulate_record(params, sample_state, record)
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Gets noised average after all records of sample have been accumulated.
|
"""See base class."""
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_state: The sample state after all records have been accumulated.
|
|
||||||
global_state: The global state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (estimate, new_global_state) where "estimate" is the estimated
|
|
||||||
average of the records and "new_global_state" is the updated global state.
|
|
||||||
"""
|
|
||||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
def normalize(v):
|
def normalize(v):
|
||||||
|
@ -104,5 +78,10 @@ class NormalizedQuery(dp_query.DPQuery):
|
||||||
|
|
||||||
return nest.map_structure(normalize, noised_sum), new_sum_global_state
|
return nest.map_structure(normalize, noised_sum), new_sum_global_state
|
||||||
|
|
||||||
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
|
"""See base class."""
|
||||||
|
return self._numerator.merge_sample_states(sample_state_1, sample_state_2)
|
||||||
|
|
||||||
def set_denominator(self, denominator):
|
def set_denominator(self, denominator):
|
||||||
|
"""Sets the denominator for the NormalizedQuery."""
|
||||||
self._denominator = tf.cast(denominator, tf.float32)
|
self._denominator = tf.cast(denominator, tf.float32)
|
||||||
|
|
Loading…
Reference in a new issue