Extensions to DPQuery and subclasses.

1. Split DPQuery.accumulate_record function into preprocess_record and accumulate_preprocessed_record.
2. Add merge_sample_state function.
3. Add default implementations for some functions in DPQuery, and add base class SumAggregationDPQuery that implements some more. Only get_noised_result is still abstract.
4. Enforce that all states and parameters used as inputs and outputs to DPQuery functions are nested structures of tensors by replacing numbers with constants and Nones with empty tuples.

PiperOrigin-RevId: 247975791
This commit is contained in:
Galen Andrew 2019-05-13 11:28:33 -07:00 committed by A. Unique TensorFlower
parent 82852c0e71
commit 1d1a6e087a
9 changed files with 221 additions and 212 deletions

View file

@ -234,13 +234,22 @@ class QueryWithLedger(dp_query.DPQuery):
"""See base class."""
return self._query.derive_sample_params(global_state)
def initial_sample_state(self, global_state, tensors):
def initial_sample_state(self, global_state, template):
"""See base class."""
return self._query.initial_sample_state(global_state, tensors)
return self._query.initial_sample_state(global_state, template)
def accumulate_record(self, params, sample_state, record):
def preprocess_record(self, params, record):
"""See base class."""
return self._query.accumulate_record(params, sample_state, record)
return self._query.preprocess_record(params, record)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""See base class."""
return self._query.accumulate_preprocessed_record(
sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""See base class."""
return self._query.merge_sample_states(sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Ensures sample is recorded to the ledger and returns noised result."""

View file

@ -5,6 +5,10 @@ licenses(["notice"]) # Apache 2.0
py_library(
name = "dp_query",
srcs = ["dp_query.py"],
deps = [
"//third_party/py/distutils",
"//third_party/py/tensorflow",
],
)
py_library(

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors.
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -47,6 +47,13 @@ from __future__ import division
from __future__ import print_function
import abc
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class DPQuery(object):
@ -54,12 +61,10 @@ class DPQuery(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def initial_global_state(self):
"""Returns the initial global state for the DPQuery."""
pass
return ()
@abc.abstractmethod
def derive_sample_params(self, global_state):
"""Given the global state, derives parameters to use for the next sample.
@ -69,25 +74,74 @@ class DPQuery(object):
Returns:
Parameters to use to process records in the next sample.
"""
pass
del global_state # unused.
return ()
@abc.abstractmethod
def initial_sample_state(self, global_state, tensors):
def initial_sample_state(self, global_state, template):
"""Returns an initial state to use for the next sample.
Args:
global_state: The current global state.
tensors: A structure of tensors used as a template to create the initial
sample state.
template: A nested structure of tensors, TensorSpecs, or numpy arrays used
as a template to create the initial sample state. It is assumed that the
leaves of the structure are python scalars or some type that has
properties `shape` and `dtype`.
Returns: An initial sample state.
"""
pass
def preprocess_record(self, params, record):
"""Preprocesses a single record.
This preprocessing is applied to one client's record, e.g. selecting vectors
and clipping them to a fixed L2 norm. This method can be executed in a
separate TF session, or even on a different machine, so it should not depend
on any TF inputs other than those provided as input arguments. In
particular, implementations should avoid accessing any TF tensors or
variables that are stored in self.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
a maximum norm magnitude to which each gradient is clipped)
record: The record to be processed. In standard DP-SGD training,
the gradient computed for the examples in one microbatch, which
may be the gradient for just one example (for size 1 microbatches).
Returns:
A structure of tensors to be aggregated.
"""
del params # unused.
return record
@abc.abstractmethod
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
"""Accumulates a single preprocessed record into the sample state.
This method is intended to only do simple aggregation, typically just a sum.
In the future, we might remove this method and replace it with a way to
declaratively specify the type of aggregation required.
Args:
sample_state: The current sample state. In standard DP-SGD training,
the accumulated sum of previous clipped microbatch gradients.
preprocessed_record: The preprocessed record to accumulate.
Returns:
The updated sample state.
"""
pass
def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
This is a helper method that simply delegates to `preprocess_record` and
`accumulate_preprocessed_record` for the common case when both of those
functions run on a single device.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
@ -102,6 +156,21 @@ class DPQuery(object):
The updated sample state. In standard DP-SGD training, the set of
previous mcrobatch gradients with the addition of the record argument.
"""
preprocessed_record = self.preprocess_record(params, record)
return self.accumulate_preprocessed_record(
sample_state, preprocessed_record)
@abc.abstractmethod
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Merges two sample states into a single state.
Args:
sample_state_1: The first sample state to merge.
sample_state_2: The second sample state to merge.
Returns:
The merged sample state.
"""
pass
@abc.abstractmethod
@ -123,3 +192,26 @@ class DPQuery(object):
averaging performed in a manner that guarantees differential privacy.
"""
pass
def zeros_like(arg):
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
try:
arg = tf.convert_to_tensor(arg)
except TypeError:
pass
return tf.zeros(arg.shape, arg.dtype)
class SumAggregationDPQuery(DPQuery):
"""Base class for DPQueries that aggregate via sum."""
def initial_sample_state(self, global_state, template):
del global_state # unused.
return nest.map_structure(zeros_like, template)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
return nest.map_structure(tf.add, sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
return nest.map_structure(tf.add, sample_state_1, sample_state_2)

View file

@ -31,7 +31,7 @@ else:
nest = tf.nest
class GaussianSumQuery(dp_query.DPQuery):
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for Gaussian sum queries.
Accumulates clipped vectors, then adds Gaussian noise to the sum.
@ -50,31 +50,10 @@ class GaussianSumQuery(dp_query.DPQuery):
self._stddev = tf.cast(stddev, tf.float32)
self._ledger = ledger
def initial_global_state(self):
"""Returns the initial global state for the GaussianSumQuery."""
return None
def derive_sample_params(self, global_state):
"""Given the global state, derives parameters to use for the next sample.
Args:
global_state: The current global state.
Returns:
Parameters to use to process records in the next sample.
"""
return self._l2_norm_clip
def initial_sample_state(self, global_state, tensors):
"""Returns an initial state to use for the next sample.
Args:
global_state: The current global state.
tensors: A structure of tensors used as a template to create the initial
sample state.
Returns: An initial sample state.
"""
def initial_sample_state(self, global_state, template):
if self._ledger:
dependencies = [
self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
@ -82,51 +61,32 @@ class GaussianSumQuery(dp_query.DPQuery):
else:
dependencies = []
with tf.control_dependencies(dependencies):
return nest.map_structure(tf.zeros_like, tensors)
return nest.map_structure(
dp_query.zeros_like, template)
def accumulate_record_impl(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
def preprocess_record_impl(self, params, record):
"""Clips the l2 norm, returning the clipped record and the l2 norm.
Args:
params: The parameters for the sample.
sample_state: The current sample state.
record: The record to accumulate.
record: The record to be processed.
Returns:
A tuple containing the updated sample state and the global norm.
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
the structure of preprocessed tensors, and l2_norm is the total l2 norm
before clipping.
"""
l2_norm_clip = params
record_as_list = nest.flatten(record)
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
clipped = nest.pack_sequence_as(record, clipped_as_list)
return nest.map_structure(tf.add, sample_state, clipped), norm
return nest.pack_sequence_as(record, clipped_as_list), norm
def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
Args:
params: The parameters for the sample.
sample_state: The current sample state.
record: The record to accumulate.
Returns:
The updated sample state.
"""
new_sample_state, _ = self.accumulate_record_impl(
params, sample_state, record)
return new_sample_state
def preprocess_record(self, params, record):
preprocessed_record, _ = self.preprocess_record_impl(params, record)
return preprocessed_record
def get_noised_result(self, sample_state, global_state):
"""Gets noised sum after all records of sample have been accumulated.
Args:
sample_state: The sample state after all records have been accumulated.
global_state: The global state.
Returns:
A tuple (estimate, new_global_state) where "estimate" is the estimated
sum of the records and "new_global_state" is the updated global state.
"""
"""See base class."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
def add_noise(v):
return v + tf.random_normal(tf.shape(v), stddev=self._stddev)

View file

@ -91,6 +91,32 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
result_stddev = np.std(noised_sums)
self.assertNear(result_stddev, stddev, 0.1)
def test_gaussian_sum_merge(self):
records1 = [tf.constant([2.0, 0.0]), tf.constant([-1.0, 1.0])]
records2 = [tf.constant([3.0, 5.0]), tf.constant([-1.0, 4.0])]
def get_sample_state(records):
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(global_state, records[0])
for record in records:
sample_state = query.accumulate_record(params, sample_state, record)
return sample_state
sample_state_1 = get_sample_state(records1)
sample_state_2 = get_sample_state(records2)
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
sample_state_1,
sample_state_2)
with self.cached_session() as sess:
result = sess.run(merged)
expected = [3.0, 10.0]
self.assertAllClose(result, expected)
def test_gaussian_average_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].

View file

@ -56,52 +56,39 @@ class NestedQuery(dp_query.DPQuery):
"""
self._queries = queries
def _map_to_queries(self, fn, *inputs):
def _map_to_queries(self, fn, *inputs, **kwargs):
def caller(query, *args):
return getattr(query, fn)(*args)
return getattr(query, fn)(*args, **kwargs)
return nest.map_structure_up_to(
self._queries, caller, self._queries, *inputs)
def initial_global_state(self):
"""Returns the initial global state for the NestedQuery."""
"""See base class."""
return self._map_to_queries('initial_global_state')
def derive_sample_params(self, global_state):
"""Given the global state, derives parameters to use for the next sample.
Args:
global_state: The current global state.
Returns:
Parameters to use to process records in the next sample.
"""
"""See base class."""
return self._map_to_queries('derive_sample_params', global_state)
def initial_sample_state(self, global_state, tensors):
"""Returns an initial state to use for the next sample.
def initial_sample_state(self, global_state, template):
"""See base class."""
return self._map_to_queries('initial_sample_state', global_state, template)
Args:
global_state: The current global state.
tensors: A structure of tensors used as a template to create the initial
sample state.
def preprocess_record(self, params, record):
"""See base class."""
return self._map_to_queries('preprocess_record', params, record)
Returns: An initial sample state.
"""
return self._map_to_queries('initial_sample_state', global_state, tensors)
def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
Args:
params: The parameters for the sample.
sample_state: The current sample state.
record: The record to accumulate.
Returns:
The updated sample state.
"""
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
"""See base class."""
return self._map_to_queries(
'accumulate_record', params, sample_state, record)
'accumulate_preprocessed_record',
sample_state,
preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
return self._map_to_queries(
'merge_sample_states', sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Gets query result after all records of sample have been accumulated.

View file

@ -28,78 +28,44 @@ else:
nest = tf.nest
class NoPrivacySumQuery(dp_query.DPQuery):
class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for a sum query with no privacy.
Accumulates vectors without clipping or adding noise.
"""
def initial_global_state(self):
"""Returns the initial global state for the NoPrivacySumQuery."""
return None
def derive_sample_params(self, global_state):
"""See base class."""
del global_state # unused.
return None
def initial_sample_state(self, global_state, tensors):
"""See base class."""
del global_state # unused.
return nest.map_structure(tf.zeros_like, tensors)
def accumulate_record(self, params, sample_state, record, weight=1):
"""See base class. Optional argument for weighted sum queries."""
del params # unused.
def add_weighted(state_tensor, record_tensor):
return tf.add(state_tensor, weight * record_tensor)
return nest.map_structure(add_weighted, sample_state, record)
def get_noised_result(self, sample_state, global_state):
"""See base class."""
return sample_state, global_state
class NoPrivacyAverageQuery(dp_query.DPQuery):
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for an average query with no privacy.
Accumulates vectors and normalizes by the total number of accumulated vectors.
"""
def __init__(self):
"""Initializes the NoPrivacyAverageQuery."""
self._numerator = NoPrivacySumQuery()
def initial_global_state(self):
"""Returns the initial global state for the NoPrivacyAverageQuery."""
return self._numerator.initial_global_state()
def derive_sample_params(self, global_state):
def initial_sample_state(self, global_state, template):
"""See base class."""
del global_state # unused.
return None
return (
super(NoPrivacyAverageQuery, self).initial_sample_state(
global_state, template),
tf.constant(0.0))
def initial_sample_state(self, global_state, tensors):
"""See base class."""
return self._numerator.initial_sample_state(global_state, tensors), 0.0
def preprocess_record(self, params, record, weight=1):
"""Multiplies record by weight."""
weighted_record = nest.map_structure(lambda t: weight * t, record)
return (weighted_record, weight)
def accumulate_record(self, params, sample_state, record, weight=1):
"""See base class. Optional argument for weighted average queries."""
sum_sample_state, denominator = sample_state
return (
self._numerator.accumulate_record(
params, sum_sample_state, record, weight),
tf.add(denominator, weight))
"""Accumulates record, multiplying by weight."""
weighted_record = nest.map_structure(lambda t: weight * t, record)
return self.accumulate_preprocessed_record(
sample_state, (weighted_record, weight))
def get_noised_result(self, sample_state, global_state):
"""See base class."""
sum_sample_state, denominator = sample_state
exact_sum, new_global_state = self._numerator.get_noised_result(
sum_sample_state, global_state)
sum_state, denominator = sample_state
def normalize(v):
return tf.truediv(v, denominator)
return nest.map_structure(normalize, exact_sum), new_global_state
return nest.map_structure(
lambda t: tf.truediv(t, denominator), sum_state), ()

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors.
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -27,7 +27,7 @@ from privacy.dp_query import test_utils
class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_no_privacy_sum(self):
def test_sum(self):
with self.cached_session() as sess:
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
@ -38,20 +38,6 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
def test_no_privacy_weighted_sum(self):
with self.cached_session() as sess:
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
weights = [1, 2]
query = no_privacy_query.NoPrivacySumQuery()
query_result, _ = test_utils.run_query(
query, [record1, record2], weights=weights)
result = sess.run(query_result)
expected = [0.0, 2.0]
self.assertAllClose(result, expected)
def test_no_privacy_average(self):
with self.cached_session() as sess:
record1 = tf.constant([5.0, 0.0])

View file

@ -38,65 +38,39 @@ class NormalizedQuery(dp_query.DPQuery):
Args:
numerator_query: A DPQuery for the numerator.
denominator: A value for the denominator.
denominator: A value for the denominator. May be None if it will be
supplied via the set_denominator function before get_noised_result is
called.
"""
self._numerator = numerator_query
self._denominator = tf.cast(denominator,
tf.float32) if denominator is not None else None
self._denominator = (
tf.cast(denominator, tf.float32) if denominator is not None else None)
def initial_global_state(self):
"""Returns the initial global state for the NormalizedQuery."""
"""See base class."""
# NormalizedQuery has no global state beyond the numerator state.
return self._numerator.initial_global_state()
def derive_sample_params(self, global_state):
"""Given the global state, derives parameters to use for the next sample.
Args:
global_state: The current global state.
Returns:
Parameters to use to process records in the next sample.
"""
"""See base class."""
return self._numerator.derive_sample_params(global_state)
def initial_sample_state(self, global_state, tensors):
"""Returns an initial state to use for the next sample.
Args:
global_state: The current global state.
tensors: A structure of tensors used as a template to create the initial
sample state.
Returns: An initial sample state.
"""
def initial_sample_state(self, global_state, template):
"""See base class."""
# NormalizedQuery has no sample state beyond the numerator state.
return self._numerator.initial_sample_state(global_state, tensors)
return self._numerator.initial_sample_state(global_state, template)
def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
def preprocess_record(self, params, record):
return self._numerator.preprocess_record(params, record)
Args:
params: The parameters for the sample.
sample_state: The current sample state.
record: The record to accumulate.
Returns:
The updated sample state.
"""
return self._numerator.accumulate_record(params, sample_state, record)
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
"""See base class."""
return self._numerator.accumulate_preprocessed_record(
sample_state, preprocessed_record)
def get_noised_result(self, sample_state, global_state):
"""Gets noised average after all records of sample have been accumulated.
Args:
sample_state: The sample state after all records have been accumulated.
global_state: The global state.
Returns:
A tuple (estimate, new_global_state) where "estimate" is the estimated
average of the records and "new_global_state" is the updated global state.
"""
"""See base class."""
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
sample_state, global_state)
def normalize(v):
@ -104,5 +78,10 @@ class NormalizedQuery(dp_query.DPQuery):
return nest.map_structure(normalize, noised_sum), new_sum_global_state
def merge_sample_states(self, sample_state_1, sample_state_2):
"""See base class."""
return self._numerator.merge_sample_states(sample_state_1, sample_state_2)
def set_denominator(self, denominator):
"""Sets the denominator for the NormalizedQuery."""
self._denominator = tf.cast(denominator, tf.float32)