Improving docstrings for DPQueries.
PiperOrigin-RevId: 378956777
This commit is contained in:
parent
4b09172c31
commit
5f07198b66
10 changed files with 394 additions and 102 deletions
|
@ -54,7 +54,31 @@ class PrivacyLedger(object):
|
|||
"""Class for keeping a record of private queries.
|
||||
|
||||
The PrivacyLedger keeps a record of all queries executed over a given dataset
|
||||
for the purpose of computing privacy guarantees.
|
||||
for the purpose of computing privacy guarantees. To use it, it must be
|
||||
associated with a `DPQuery` object via a `QueryWithLedger`.
|
||||
|
||||
The current implementation works only with DPQueries that consist of composing
|
||||
Gaussian sum mechanism with Poisson subsampling.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
import tensorflow_privacy as tfp
|
||||
|
||||
dp_query = tfp.QueryWithLedger(
|
||||
tensorflow_privacy.GaussianSumQuery(
|
||||
l2_norm_clip=1.0, stddev=1.0),
|
||||
population_size=10000,
|
||||
selection_probability=0.01)
|
||||
|
||||
# Use dp_query here in training loop.
|
||||
|
||||
formatted_ledger = dp_query.ledger.get_formatted_ledger_eager()
|
||||
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
|
||||
list(range(5, 64)) + [128, 256, 512])
|
||||
total_rdp = tfp.compute_rdp_from_ledger(formatted_ledger, orders)
|
||||
epsilon = tfp.get_privacy_spent(orders, total_rdp, target_delta=1e-5)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -106,7 +130,8 @@ class PrivacyLedger(object):
|
|||
noise_stddev: The standard deviation of the noise applied to the sum.
|
||||
|
||||
Returns:
|
||||
An operation recording the sum query to the ledger.
|
||||
An operation recording the sum query to the ledger. This should be called
|
||||
for every Gaussian sum query that is issued on a sample.
|
||||
"""
|
||||
|
||||
def _do_record_query():
|
||||
|
@ -118,7 +143,15 @@ class PrivacyLedger(object):
|
|||
return self._cs.execute(_do_record_query)
|
||||
|
||||
def finalize_sample(self):
|
||||
"""Finalizes sample and records sample ledger entry."""
|
||||
"""Finalizes sample and records sample ledger entry.
|
||||
|
||||
This should be called once per application of the mechanism on a sample,
|
||||
after all sum queries have been recorded.
|
||||
|
||||
Returns:
|
||||
An operation recording the complete mechanism (sampling and sum
|
||||
estimation) to the ledger.
|
||||
"""
|
||||
with tf.control_dependencies([
|
||||
tf.assign(self._sample_var, [
|
||||
self._population_size, self._selection_probability,
|
||||
|
@ -132,6 +165,7 @@ class PrivacyLedger(object):
|
|||
return self._sample_buffer.append(self._sample_var)
|
||||
|
||||
def get_unformatted_ledger(self):
|
||||
"""Returns the raw sample and query values."""
|
||||
return self._sample_buffer.values, self._query_buffer.values
|
||||
|
||||
def get_formatted_ledger(self, sess):
|
||||
|
@ -169,7 +203,10 @@ class QueryWithLedger(dp_query.DPQuery):
|
|||
those contained in the leaves of a nested query) should also contain a
|
||||
reference to the same ledger object.
|
||||
|
||||
For example usage, see `privacy_ledger_test.py`.
|
||||
Only composed Gaussian sum queries with Poisson subsampling are supported.
|
||||
This includes `GaussianSumQuery`, `QuantileEstimatorQuery`, and
|
||||
`QuantileAdaptiveClipSumQuery`, as well as `NestedQuery` or `NormalizedQuery`
|
||||
objects that contain the previous mentioned query types.
|
||||
"""
|
||||
|
||||
def __init__(self, query,
|
||||
|
@ -185,8 +222,8 @@ class QueryWithLedger(dp_query.DPQuery):
|
|||
population, i.e. size of the training data used in each epoch. May be
|
||||
`None` if `ledger` is specified.
|
||||
selection_probability: A floating point value (may be variable) specifying
|
||||
the probability each record is included in a sample. May be `None` if
|
||||
`ledger` is specified.
|
||||
the probability each record is included in a sample under Poisson
|
||||
subsampling. May be `None` if `ledger` is specified.
|
||||
ledger: A `PrivacyLedger` to use. Must be specified if either of
|
||||
`population_size` or `selection_probability` is `None`.
|
||||
"""
|
||||
|
@ -201,46 +238,62 @@ class QueryWithLedger(dp_query.DPQuery):
|
|||
|
||||
@property
|
||||
def ledger(self):
|
||||
"""Gets the ledger that all inner queries record to."""
|
||||
return self._ledger
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Sets a new ledger."""
|
||||
self._ledger = ledger
|
||||
self._query.set_ledger(ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return self._query.initial_global_state()
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return self._query.derive_sample_params(global_state)
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
return self._query.initial_sample_state(template)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
return self._query.preprocess_record(params, record)
|
||||
|
||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
|
||||
return self._query.accumulate_preprocessed_record(
|
||||
sample_state, preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
|
||||
return self._query.merge_sample_states(sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Ensures sample is recorded to the ledger and returns noised result."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`.
|
||||
|
||||
Besides noising and returning the result of the inner query, ensures that
|
||||
the sample is recorded to the ledger.
|
||||
|
||||
Args:
|
||||
sample_state: The sample state after all records have been accumulated.
|
||||
global_state: The global state, storing long-term privacy bookkeeping.
|
||||
|
||||
Returns:
|
||||
A tuple (result, new_global_state) where "result" is the result of the
|
||||
query and "new_global_state" is the updated global state.
|
||||
"""
|
||||
# Ensure sample_state is fully aggregated before calling get_noised_result.
|
||||
with tf.control_dependencies(tf.nest.flatten(sample_state)):
|
||||
result, new_global_state = self._query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
|
||||
# Ensure inner queries have recorded before finalizing.
|
||||
with tf.control_dependencies(tf.nest.flatten(result)):
|
||||
finalize = self._ledger.finalize_sample()
|
||||
|
||||
# Ensure finalizing happens.
|
||||
with tf.control_dependencies([finalize]):
|
||||
return tf.nest.map_structure(tf.identity, result), new_global_state
|
||||
|
|
|
@ -53,13 +53,58 @@ import tensorflow.compat.v1 as tf
|
|||
|
||||
|
||||
class DPQuery(object):
|
||||
"""Interface for differentially private query mechanisms."""
|
||||
"""Interface for differentially private query mechanisms.
|
||||
|
||||
Differential privacy is achieved by processing records to bound sensitivity,
|
||||
accumulating the processed records (usually by summing them) and then
|
||||
adding noise to the aggregated result. The process can be repeated to compose
|
||||
applications of the same mechanism, possibly with different parameters.
|
||||
|
||||
The DPQuery interface specifies a functional approach to this process. A
|
||||
global state maintains state that persists across applications of the
|
||||
mechanism. For each application, the following steps are performed:
|
||||
|
||||
1. Use the global state to derive parameters to use for the next sample of
|
||||
records.
|
||||
2. Initialize a sample state that will accumulate processed records.
|
||||
3. For each record:
|
||||
a. Process the record.
|
||||
b. Accumulate the record into the sample state.
|
||||
4. Get the result of the mechanism, possibly updating the global state to use
|
||||
in the next application.
|
||||
5. Derive metrics from the global state.
|
||||
|
||||
Here is an example using the GaussianSumQuery. Assume there is some function
|
||||
records_for_round(round) that returns an iterable of records to use on some
|
||||
round.
|
||||
|
||||
```
|
||||
dp_query = tensorflow_privacy.GaussianSumQuery(
|
||||
l2_norm_clip=1.0, stddev=1.0)
|
||||
global_state = dp_query.initial_global_state()
|
||||
|
||||
for round in range(num_rounds):
|
||||
sample_params = dp_query.derive_sample_params(global_state)
|
||||
sample_state = dp_query.initial_sample_state()
|
||||
for record in records_for_round(round):
|
||||
sample_state = dp_query.accumulate_record(
|
||||
sample_params, sample_state, record)
|
||||
|
||||
result, global_state = dp_query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
metrics = dp_query.derive_metrics(global_state)
|
||||
|
||||
# Do something with result and metrics...
|
||||
```
|
||||
"""
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Supplies privacy ledger to which the query can record privacy events.
|
||||
|
||||
The ledger should be updated with each call to get_noised_result.
|
||||
|
||||
Args:
|
||||
ledger: A `PrivacyLedger`.
|
||||
"""
|
||||
|
@ -68,12 +113,26 @@ class DPQuery(object):
|
|||
'DPQuery type %s does not support set_ledger.' % type(self).__name__)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the DPQuery."""
|
||||
"""Returns the initial global state for the DPQuery.
|
||||
|
||||
The global state contains any state information that changes across
|
||||
repeated applications of the mechanism. The default implementation returns
|
||||
just an empty tuple for implementing classes that do not have any persistent
|
||||
state.
|
||||
|
||||
Returns:
|
||||
The global state.
|
||||
"""
|
||||
return ()
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
||||
For example, if the mechanism needs to clip records to bound the norm,
|
||||
the clipping norm should be part of the sample params. In a distributed
|
||||
context, this is the part of the state that would be sent to the workers
|
||||
so they can process records.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
|
||||
|
@ -87,6 +146,10 @@ class DPQuery(object):
|
|||
def initial_sample_state(self, template=None):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
||||
For typical `DPQuery` classes that are aggregated by summation, this should
|
||||
return a nested structure of zero tensors of the appropriate shapes, to
|
||||
which processed records will be aggregated.
|
||||
|
||||
Args:
|
||||
template: A nested structure of tensors, TensorSpecs, or numpy arrays used
|
||||
as a template to create the initial sample state. It is assumed that the
|
||||
|
@ -145,7 +208,7 @@ class DPQuery(object):
|
|||
|
||||
This is a helper method that simply delegates to `preprocess_record` and
|
||||
`accumulate_preprocessed_record` for the common case when both of those
|
||||
functions run on a single device.
|
||||
functions run on a single device. Typically this will be a simple sum.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample. In standard DP-SGD training,
|
||||
|
@ -169,6 +232,11 @@ class DPQuery(object):
|
|||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""Merges two sample states into a single state.
|
||||
|
||||
This can be useful if aggregation is performed hierarchically, where
|
||||
multiple sample states are used to accumulate records and then
|
||||
hierarchically merged into the final accumulated state. Typically this will
|
||||
be a simple sum.
|
||||
|
||||
Args:
|
||||
sample_state_1: The first sample state to merge.
|
||||
sample_state_2: The second sample state to merge.
|
||||
|
@ -180,11 +248,14 @@ class DPQuery(object):
|
|||
|
||||
@abc.abstractmethod
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Gets query result after all records of sample have been accumulated.
|
||||
"""Gets the query result after all records of sample have been accumulated.
|
||||
|
||||
The global state can also be updated for use in the next application of the
|
||||
DP mechanism.
|
||||
|
||||
Args:
|
||||
sample_state: The sample state after all records have been accumulated.
|
||||
In standard DP-SGD training, the accumulated sum of clipped microbatch
|
||||
sample_state: The sample state after all records have been accumulated. In
|
||||
standard DP-SGD training, the accumulated sum of clipped microbatch
|
||||
gradients (in the special case of microbatches of size 1, the clipped
|
||||
per-example gradients).
|
||||
global_state: The global state, storing long-term privacy bookkeeping.
|
||||
|
@ -213,7 +284,7 @@ class DPQuery(object):
|
|||
return collections.OrderedDict()
|
||||
|
||||
|
||||
def zeros_like(arg):
|
||||
def _zeros_like(arg):
|
||||
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
|
||||
try:
|
||||
arg = tf.convert_to_tensor(value=arg)
|
||||
|
@ -222,7 +293,8 @@ def zeros_like(arg):
|
|||
return tf.zeros(arg.shape, arg.dtype)
|
||||
|
||||
|
||||
def safe_add(x, y):
|
||||
def _safe_add(x, y):
|
||||
"""Adds x and y but if y is None, simply returns x."""
|
||||
return x if y is None else tf.add(x, y)
|
||||
|
||||
|
||||
|
@ -230,13 +302,17 @@ class SumAggregationDPQuery(DPQuery):
|
|||
"""Base class for DPQueries that aggregate via sum."""
|
||||
|
||||
def initial_sample_state(self, template=None):
|
||||
return tf.nest.map_structure(zeros_like, template)
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
return tf.nest.map_structure(_zeros_like, template)
|
||||
|
||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||
return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
|
||||
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
|
||||
return tf.nest.map_structure(_safe_add, sample_state, preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
|
||||
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
return sample_state, global_state
|
||||
|
|
|
@ -28,7 +28,7 @@ from tensorflow_privacy.privacy.dp_query import dp_query
|
|||
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery interface for Gaussian sum queries.
|
||||
|
||||
Accumulates clipped vectors, then adds Gaussian noise to the sum.
|
||||
Clips records to bound the L2 norm, then adds Gaussian noise to the sum.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
@ -48,6 +48,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
self._ledger = None
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
|
||||
self._ledger = ledger
|
||||
|
||||
def make_global_state(self, l2_norm_clip, stddev):
|
||||
|
@ -56,9 +57,11 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
tf.cast(l2_norm_clip, tf.float32), tf.cast(stddev, tf.float32))
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return self.make_global_state(self._l2_norm_clip, self._stddev)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return global_state.l2_norm_clip
|
||||
|
||||
def preprocess_record_impl(self, params, record):
|
||||
|
@ -79,11 +82,12 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return tf.nest.pack_sequence_as(record, clipped_as_list), norm
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
preprocessed_record, _ = self.preprocess_record_impl(params, record)
|
||||
return preprocessed_record
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
if distutils.version.LooseVersion(
|
||||
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
||||
|
||||
|
|
|
@ -32,12 +32,18 @@ class NestedQuery(dp_query.DPQuery):
|
|||
NestedQuery evaluates arbitrary nested structures of queries. Records must be
|
||||
nested structures of tensors that are compatible (in type and arity) with the
|
||||
query structure, but are allowed to have deeper structure within each leaf of
|
||||
the query structure. For example, the nested query [q1, q2] is compatible with
|
||||
the record [t1, t2] or [t1, (t2, t3)], but not with (t1, t2), [t1] or
|
||||
[t1, t2, t3]. The entire substructure of each record corresponding to a leaf
|
||||
node of the query structure is routed to the corresponding query. If the same
|
||||
tensor should be consumed by multiple sub-queries, it can be replicated in the
|
||||
record, for example [t1, t1].
|
||||
the query structure. The entire substructure of each record corresponding to a
|
||||
leaf node of the query structure is routed to the corresponding query.
|
||||
|
||||
For example, a nested query with structure "[q1, q2]" is compatible with a
|
||||
record of structure "[t1, (t2, t3)]": t1 would be processed by q1, and (t2,
|
||||
t3) would be processed by q2. On the other hand, "[q1, q2]" is not compatible
|
||||
with "(t1, t2)" (type mismatch), "[t1]" (arity-mismatch) or "[t1, t2, t3]"
|
||||
(arity-mismatch).
|
||||
|
||||
It is possible for the same tensor to be consumed by multiple sub-queries, by
|
||||
simply replicating it in the record, for example providing "[t1, t1]" to
|
||||
"[q1, q2]".
|
||||
|
||||
NestedQuery is intended to allow privacy mechanisms for groups as described in
|
||||
[McMahan & Andrew, 2018: "A General Approach to Adding Differential Privacy to
|
||||
|
@ -61,35 +67,43 @@ class NestedQuery(dp_query.DPQuery):
|
|||
*inputs)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
|
||||
self._map_to_queries('set_ledger', ledger=ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return self._map_to_queries('initial_global_state')
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return self._map_to_queries('derive_sample_params', global_state)
|
||||
|
||||
def initial_sample_state(self, template=None):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
if template is None:
|
||||
return self._map_to_queries('initial_sample_state')
|
||||
else:
|
||||
return self._map_to_queries('initial_sample_state', template)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
return self._map_to_queries('preprocess_record', params, record)
|
||||
|
||||
def accumulate_preprocessed_record(
|
||||
self, sample_state, preprocessed_record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
|
||||
return self._map_to_queries(
|
||||
'accumulate_preprocessed_record',
|
||||
sample_state,
|
||||
preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
|
||||
return self._map_to_queries(
|
||||
'merge_sample_states', sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
estimates_and_new_global_states = self._map_to_queries(
|
||||
'get_noised_result', sample_state, global_state)
|
||||
|
||||
|
@ -99,6 +113,7 @@ class NestedQuery(dp_query.DPQuery):
|
|||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
metrics = collections.OrderedDict()
|
||||
|
||||
def add_metrics(tuple_path, subquery, subquery_global_state):
|
||||
|
@ -122,6 +137,8 @@ class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery):
|
|||
Args:
|
||||
queries: A nested structure of queries that must all be
|
||||
SumAggregationDPQueries.
|
||||
|
||||
Raises: TypeError if any of the subqueries are not SumAggregationDPQueries.
|
||||
"""
|
||||
def check(query):
|
||||
if not isinstance(query, dp_query.SumAggregationDPQuery):
|
||||
|
|
|
@ -34,13 +34,14 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
|
|||
self._ledger = None
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
|
||||
warnings.warn(
|
||||
'Attempt to use NoPrivacySumQuery with privacy ledger. Privacy '
|
||||
'guarantees will be vacuous.')
|
||||
self._ledger = ledger
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
|
||||
if self._ledger:
|
||||
dependencies = [
|
||||
|
@ -57,35 +58,67 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
|||
"""Implements DPQuery interface for an average query with no privacy.
|
||||
|
||||
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
||||
Under some sampling schemes, such as Poisson subsampling, the number of
|
||||
records in a sample is a private quantity, so we lose all privacy guarantees
|
||||
by using the number of records directly to normalize.
|
||||
|
||||
Also allows weighted accumulation, unlike the base class DPQuery. In a private
|
||||
implementation of weighted average, the weight would have to be itself
|
||||
privatized.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the NoPrivacyAverageQuery."""
|
||||
self._ledger = None
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
|
||||
warnings.warn(
|
||||
'Attempt to use NoPrivacyAverageQuery with privacy ledger. Privacy '
|
||||
'guarantees will be vacuous.')
|
||||
self._ledger = ledger
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
|
||||
tf.constant(0.0))
|
||||
|
||||
def preprocess_record(self, params, record, weight=1):
|
||||
"""Multiplies record by weight."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||
|
||||
Optional `weight` argument allows weighted accumulation.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
record: The record to accumulate.
|
||||
weight: Optional weight for the record.
|
||||
|
||||
Returns:
|
||||
The preprocessed record.
|
||||
"""
|
||||
weighted_record = tf.nest.map_structure(lambda t: weight * t, record)
|
||||
return (weighted_record, tf.cast(weight, tf.float32))
|
||||
|
||||
def accumulate_record(self, params, sample_state, record, weight=1):
|
||||
"""Accumulates record, multiplying by weight."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.accumulate_record`.
|
||||
|
||||
Optional `weight` argument allows weighted accumulation.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
sample_state: The current sample state.
|
||||
record: The record to accumulate.
|
||||
weight: Optional weight for the record.
|
||||
|
||||
Returns:
|
||||
The updated sample state.
|
||||
"""
|
||||
weighted_record = tf.nest.map_structure(lambda t: weight * t, record)
|
||||
return self.accumulate_preprocessed_record(
|
||||
sample_state, (weighted_record, tf.cast(weight, tf.float32)))
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
sum_state, denominator = sample_state
|
||||
|
||||
if self._ledger:
|
||||
|
|
|
@ -27,14 +27,22 @@ from tensorflow_privacy.privacy.dp_query import dp_query
|
|||
|
||||
|
||||
class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||
"""DPQuery for queries with a DPQuery numerator and fixed denominator."""
|
||||
"""`DPQuery` for queries with a `DPQuery` numerator and fixed denominator.
|
||||
|
||||
If the number of records per round is a public constant R, `NormalizedQuery`
|
||||
could be used with a sum query as the numerator and R as the denominator to
|
||||
implement an average. Under some sampling schemes, such as Poisson
|
||||
subsampling, the actual number of records in a sample is a private quantity,
|
||||
so we cannot use it directly. Using this class with the expected number of
|
||||
records as the denominator gives an unbiased estimate of the average.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple(
|
||||
'_GlobalState', ['numerator_state', 'denominator'])
|
||||
|
||||
def __init__(self, numerator_query, denominator):
|
||||
"""Initializer for NormalizedQuery.
|
||||
"""Initializes the NormalizedQuery.
|
||||
|
||||
Args:
|
||||
numerator_query: A SumAggregationDPQuery for the numerator.
|
||||
|
@ -48,27 +56,30 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
|
||||
self._numerator.set_ledger(ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
if self._denominator is not None:
|
||||
denominator = tf.cast(self._denominator, tf.float32)
|
||||
else:
|
||||
denominator = None
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
denominator = tf.cast(self._denominator, tf.float32)
|
||||
return self._GlobalState(
|
||||
self._numerator.initial_global_state(), denominator)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return self._numerator.derive_sample_params(global_state.numerator_state)
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
# NormalizedQuery has no sample state beyond the numerator state.
|
||||
return self._numerator.initial_sample_state(template)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
return self._numerator.preprocess_record(params, record)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||
sample_state, global_state.numerator_state)
|
||||
def normalize(v):
|
||||
|
@ -78,4 +89,5 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
self._GlobalState(new_sum_global_state, global_state.denominator))
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
return self._numerator.derive_metrics(global_state.numerator_state)
|
||||
|
|
|
@ -11,14 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implements DPQuery interface for adaptive clip queries.
|
||||
|
||||
Instead of a fixed clipping norm specified in advance, the clipping norm is
|
||||
dynamically adjusted to match a target fraction of clipped updates per sample,
|
||||
where the actual fraction of clipped updates is itself estimated in a
|
||||
differentially private manner. For details see Thakkar et al., "Differentially
|
||||
Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871].
|
||||
"""
|
||||
"""`DPQuery` for Gaussian sum queries with adaptive clipping."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -34,10 +27,12 @@ from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
|
|||
|
||||
|
||||
class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""DPQuery for sum queries with adaptive clipping.
|
||||
"""`DPQuery` for Gaussian sum queries with adaptive clipping.
|
||||
|
||||
Clipping norm is tuned adaptively to converge to a value such that a specified
|
||||
quantile of updates are clipped.
|
||||
quantile of updates are clipped, using the algorithm of Andrew et al. (
|
||||
https://arxiv.org/abs/1905.03871). See the paper for details and suggested
|
||||
hyperparameter settings.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
@ -65,20 +60,23 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
Args:
|
||||
initial_l2_norm_clip: The initial value of clipping norm.
|
||||
noise_multiplier: The multiplier of the l2_norm_clip to make the stddev of
|
||||
the noise added to the output of the sum query.
|
||||
noise_multiplier: The stddev of the noise added to the output will be this
|
||||
times the current value of the clipping norm.
|
||||
target_unclipped_quantile: The desired quantile of updates which should be
|
||||
unclipped. I.e., a value of 0.8 means a value of l2_norm_clip should be
|
||||
found for which approximately 20% of updates are clipped each round.
|
||||
learning_rate: The learning rate for the clipping norm adaptation. A rate
|
||||
of r means that the clipping norm will change by a maximum of r at each
|
||||
step. This maximum is attained when |clip - target| is 1.0.
|
||||
Andrew et al. recommends that this be set to 0.5 to clip to the median.
|
||||
learning_rate: The learning rate for the clipping norm adaptation. With
|
||||
geometric updating, a rate of r means that the clipping norm will change
|
||||
by a maximum factor of exp(r) at each round. This maximum is attained
|
||||
when |actual_unclipped_fraction - target_unclipped_quantile| is 1.0.
|
||||
Andrew et al. recommends that this be set to 0.2 for geometric updating.
|
||||
clipped_count_stddev: The stddev of the noise added to the clipped_count.
|
||||
Since the sensitivity of the clipped count is 0.5, as a rule of thumb it
|
||||
should be about 0.5 for reasonable privacy.
|
||||
Andrew et al. recommends that this be set to `expected_num_records / 20`
|
||||
for reasonably fast adaptation and high privacy.
|
||||
expected_num_records: The expected number of records per round, used to
|
||||
estimate the clipped count quantile.
|
||||
geometric_update: If True, use geometric updating of clip.
|
||||
geometric_update: If `True`, use geometric updating of clip (recommended).
|
||||
"""
|
||||
self._noise_multiplier = noise_multiplier
|
||||
|
||||
|
@ -94,27 +92,32 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
dp_query.SumAggregationDPQuery)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
|
||||
self._sum_query.set_ledger(ledger)
|
||||
self._quantile_estimator_query.set_ledger(ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return self._GlobalState(
|
||||
tf.cast(self._noise_multiplier, tf.float32),
|
||||
self._sum_query.initial_global_state(),
|
||||
self._quantile_estimator_query.initial_global_state())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return self._SampleParams(
|
||||
self._sum_query.derive_sample_params(global_state.sum_state),
|
||||
self._quantile_estimator_query.derive_sample_params(
|
||||
global_state.quantile_estimator_state))
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
return self._SampleState(
|
||||
self._sum_query.initial_sample_state(template),
|
||||
self._quantile_estimator_query.initial_sample_state())
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
clipped_record, global_norm = (
|
||||
self._sum_query.preprocess_record_impl(params.sum_params, record))
|
||||
|
||||
|
@ -124,6 +127,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return self._SampleState(clipped_record, was_unclipped)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
||||
sample_state.sum_state, global_state.sum_state)
|
||||
del sum_state # To be set explicitly later when we know the new clip.
|
||||
|
@ -145,4 +149,5 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return noised_vectors, new_global_state
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Returns the current clipping norm as a metric."""
|
||||
return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip)
|
||||
|
|
|
@ -11,13 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implements DPQuery interface for quantile estimator.
|
||||
|
||||
From a starting estimate of the target quantile, the estimate is updated
|
||||
dynamically where the fraction of below_estimate updates is estimated in a
|
||||
differentially private manner. For details see Thakkar et al., "Differentially
|
||||
Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871].
|
||||
"""
|
||||
"""Implements DPQuery interface for quantile estimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -33,7 +27,11 @@ from tensorflow_privacy.privacy.dp_query import normalized_query
|
|||
|
||||
|
||||
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Iterative process to estimate target quantile of a univariate distribution."""
|
||||
"""DPQuery to estimate target quantile of a univariate distribution.
|
||||
|
||||
Uses the algorithm of Andrew et al. (https://arxiv.org/abs/1905.03871). See
|
||||
the paper for details and suggested hyperparameter settings.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple('_GlobalState', [
|
||||
|
@ -55,7 +53,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
below_estimate_stddev,
|
||||
expected_num_records,
|
||||
geometric_update=False):
|
||||
"""Initializes the QuantileAdaptiveClipSumQuery.
|
||||
"""Initializes the QuantileEstimatorQuery.
|
||||
|
||||
Args:
|
||||
initial_estimate: The initial estimate of the quantile.
|
||||
|
@ -64,11 +62,12 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
estimate each round.
|
||||
learning_rate: The learning rate. A rate of r means that the estimate will
|
||||
change by a maximum of r at each step (for arithmetic updating) or by a
|
||||
maximum factor of exp(r) (for geometric updating).
|
||||
maximum factor of exp(r) (for geometric updating). Andrew et al.
|
||||
recommends that this be set to 0.2 for geometric updating.
|
||||
below_estimate_stddev: The stddev of the noise added to the count of
|
||||
records currently below the estimate. Since the sensitivity of the count
|
||||
query is 0.5, as a rule of thumb it should be about 0.5 for reasonable
|
||||
privacy.
|
||||
records currently below the estimate. Andrew et al. recommends that this
|
||||
be set to `expected_num_records / 20` for reasonably fast adaptation and
|
||||
high privacy.
|
||||
expected_num_records: The expected number of records per round.
|
||||
geometric_update: If True, use geometric updating of estimate. Geometric
|
||||
updating is preferred for non-negative records like vector norms that
|
||||
|
@ -102,9 +101,11 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
denominator=expected_num_records)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
|
||||
self._below_estimate_query.set_ledger(ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return self._GlobalState(
|
||||
tf.cast(self._initial_estimate, tf.float32),
|
||||
tf.cast(self._target_quantile, tf.float32),
|
||||
|
@ -112,39 +113,42 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
self._below_estimate_query.initial_global_state())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
below_estimate_params = self._below_estimate_query.derive_sample_params(
|
||||
global_state.below_estimate_state)
|
||||
return self._SampleParams(global_state.current_estimate,
|
||||
below_estimate_params)
|
||||
|
||||
def initial_sample_state(self, template=None):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
# Template is ignored because records are required to be scalars.
|
||||
del template
|
||||
|
||||
return self._below_estimate_query.initial_sample_state(0.0)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
|
||||
tf.debugging.assert_scalar(record)
|
||||
|
||||
# We accumulate counts shifted by 0.5 so they are centered at zero.
|
||||
# This makes the sensitivity of the count query 0.5 instead of 1.0.
|
||||
# Shift counts by 0.5 so they are centered at zero. (See comment in
|
||||
# `_construct_below_estimate_query`.)
|
||||
below = tf.cast(record <= params.current_estimate, tf.float32) - 0.5
|
||||
return self._below_estimate_query.preprocess_record(
|
||||
params.below_estimate_params, below)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
below_estimate_result, new_below_estimate_state = (
|
||||
self._below_estimate_query.get_noised_result(
|
||||
sample_state, global_state.below_estimate_state))
|
||||
|
||||
# Unshift below_estimate percentile by 0.5. (See comment in initializer.)
|
||||
# Unshift below_estimate percentile by 0.5. (See comment in
|
||||
# `_construct_below_estimate_query`.)
|
||||
below_estimate = below_estimate_result + 0.5
|
||||
|
||||
# Protect against out-of-range estimates.
|
||||
below_estimate = tf.minimum(1.0, tf.maximum(0.0, below_estimate))
|
||||
|
||||
# Loss function is convex, with derivative in [-1, 1], and minimized when
|
||||
# the true quantile matches the target.
|
||||
loss_grad = below_estimate - global_state.target_quantile
|
||||
|
||||
update = global_state.learning_rate * loss_grad
|
||||
|
@ -161,6 +165,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
return new_estimate, new_global_state
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
return collections.OrderedDict(estimate=global_state.current_estimate)
|
||||
|
||||
|
||||
|
@ -168,7 +173,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
|||
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||
|
||||
Unlike the base class, this uses a NoPrivacyQuery to estimate the fraction
|
||||
below estimate with an exact denominator.
|
||||
below estimate with an exact denominator, so there are no privacy guarantees.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -185,7 +190,8 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
|||
estimate each round.
|
||||
learning_rate: The learning rate. A rate of r means that the estimate will
|
||||
change by a maximum of r at each step (for arithmetic updating) or by a
|
||||
maximum factor of exp(r) (for geometric updating).
|
||||
maximum factor of exp(r) (for geometric updating). Andrew et al.
|
||||
recommends that this be set to 0.2 for geometric updating.
|
||||
geometric_update: If True, use geometric updating of estimate. Geometric
|
||||
updating is preferred for non-negative records like vector norms that
|
||||
could potentially be very large or very close to zero.
|
||||
|
|
|
@ -27,29 +27,63 @@ import tensorflow as tf
|
|||
|
||||
|
||||
class ValueGenerator(metaclass=abc.ABCMeta):
|
||||
"""Base class establishing interface for stateful value generation."""
|
||||
"""Base class establishing interface for stateful value generation.
|
||||
|
||||
A `ValueGenerator` maintains a state, and each time `next` is called, a new
|
||||
value is generated and the state is advanced.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def initialize(self):
|
||||
"""Returns initialized state."""
|
||||
"""Makes an initialized state for the ValueGenerator.
|
||||
|
||||
Returns:
|
||||
An initial state.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def next(self, state):
|
||||
"""Returns tree node value and updated state."""
|
||||
"""Gets next value and advances the ValueGenerator.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is the next value and new_state
|
||||
is the advanced state.
|
||||
"""
|
||||
|
||||
|
||||
class GaussianNoiseGenerator(ValueGenerator):
|
||||
"""Gaussian noise generator with counter as pseudo state."""
|
||||
"""Gaussian noise generator with counter as pseudo state.
|
||||
|
||||
Produces i.i.d. spherical Gaussian noise at each step shaped according to a
|
||||
nested structure of `tf.TensorSpec`s.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
noise_std: float,
|
||||
specs: Collection[tf.TensorSpec],
|
||||
seed: Optional[int] = None):
|
||||
"""Initializes the GaussianNoiseGenerator.
|
||||
|
||||
Args:
|
||||
noise_std: The standard deviation of the noise.
|
||||
specs: A nested structure of `tf.TensorSpec`s specifying the shape of the
|
||||
noise to generate.
|
||||
seed: An optional integer seed. If None, generator is seeded from the
|
||||
clock.
|
||||
"""
|
||||
self.noise_std = noise_std
|
||||
self.specs = specs
|
||||
self.seed = seed
|
||||
|
||||
def initialize(self):
|
||||
"""Makes an initial state for the GaussianNoiseGenerator.
|
||||
|
||||
Returns:
|
||||
An initial state.
|
||||
"""
|
||||
if self.seed is None:
|
||||
return tf.cast(
|
||||
tf.stack([
|
||||
|
@ -61,6 +95,15 @@ class GaussianNoiseGenerator(ValueGenerator):
|
|||
return tf.constant(self.seed, dtype=tf.int64, shape=(2,))
|
||||
|
||||
def next(self, state):
|
||||
"""Gets next value and advances the GaussianNoiseGenerator.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
|
||||
Returns:
|
||||
A pair (sample, new_state) where sample is a new sample and new_state
|
||||
is the advanced state.
|
||||
"""
|
||||
flat_structure = tf.nest.flatten(self.specs)
|
||||
flat_seeds = [state + i for i in range(len(flat_structure))]
|
||||
nest_seeds = tf.nest.pack_sequence_as(self.specs, flat_seeds)
|
||||
|
@ -74,15 +117,34 @@ class GaussianNoiseGenerator(ValueGenerator):
|
|||
|
||||
|
||||
class StatelessValueGenerator(ValueGenerator):
|
||||
"""A wrapper for stateless value generator initialized by a no-arg function."""
|
||||
"""A wrapper for stateless value generator that calls a no-arg function."""
|
||||
|
||||
def __init__(self, value_fn):
|
||||
"""Initializes the StatelessValueGenerator.
|
||||
|
||||
Args:
|
||||
value_fn: The function to call to generate values.
|
||||
"""
|
||||
self.value_fn = value_fn
|
||||
|
||||
def initialize(self):
|
||||
"""Makes an initialized state for the StatelessValueGenerator.
|
||||
|
||||
Returns:
|
||||
An initial state (empty, because stateless).
|
||||
"""
|
||||
return ()
|
||||
|
||||
def next(self, state):
|
||||
"""Gets next value.
|
||||
|
||||
Args:
|
||||
state: The current state (simply passed through).
|
||||
|
||||
Returns:
|
||||
A pair (value, new_state) where value is the next value and new_state
|
||||
is the advanced state.
|
||||
"""
|
||||
return self.value_fn(), state
|
||||
|
||||
|
||||
|
@ -127,7 +189,12 @@ class TreeAggregator():
|
|||
"""
|
||||
|
||||
def __init__(self, value_generator: Union[ValueGenerator, Callable[[], Any]]):
|
||||
"""Initialize the aggregator with a noise generator."""
|
||||
"""Initialize the aggregator with a noise generator.
|
||||
|
||||
Args:
|
||||
value_generator: A `ValueGenerator` or a no-arg function to generate a
|
||||
noise value for each tree node.
|
||||
"""
|
||||
if isinstance(value_generator, ValueGenerator):
|
||||
self.value_generator = value_generator
|
||||
else:
|
||||
|
@ -235,7 +302,7 @@ class EfficientTreeAggregator():
|
|||
|
||||
This class implements the efficient tree aggregation algorithm based on
|
||||
Honaker 2015 "Efficient Use of Differentially Private Binary Trees".
|
||||
The noise standard deviation for the note at depth d is roughly
|
||||
The noise standard deviation for a node at depth d is roughly
|
||||
`sigma * sqrt(2^{d-1}/(2^d-1))`. which becomes `sigma / sqrt(2)` when
|
||||
the tree is very tall.
|
||||
|
||||
|
@ -245,7 +312,12 @@ class EfficientTreeAggregator():
|
|||
"""
|
||||
|
||||
def __init__(self, value_generator: Union[ValueGenerator, Callable[[], Any]]):
|
||||
"""Initialize the aggregator with a noise generator."""
|
||||
"""Initialize the aggregator with a noise generator.
|
||||
|
||||
Args:
|
||||
value_generator: A `ValueGenerator` or a no-arg function to generate a
|
||||
noise value for each tree node.
|
||||
"""
|
||||
if isinstance(value_generator, ValueGenerator):
|
||||
self.value_generator = value_generator
|
||||
else:
|
||||
|
@ -257,6 +329,9 @@ class EfficientTreeAggregator():
|
|||
Initializes `TreeState` for a tree of a single leaf node: the respective
|
||||
initial node value in `TreeState.level_buffer` is generated by the value
|
||||
generator function, and the node index is 0.
|
||||
|
||||
Returns:
|
||||
An initialized `TreeState`.
|
||||
"""
|
||||
value_generator_state = self.value_generator.initialize()
|
||||
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
|
||||
|
|
|
@ -68,7 +68,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
`TreeCumulativeSumQuery` with L2 norm clipping and Gaussian noise.
|
||||
|
||||
Args:
|
||||
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
|
||||
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
|
||||
and shapes of records.
|
||||
noise_generator: `tree_aggregation.ValueGenerator` to generate the noise
|
||||
value for a tree node. Should be coupled with clipping norm to guarantee
|
||||
privacy.
|
||||
|
@ -89,7 +90,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns initial global state."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
initial_tree_state = self._tree_aggregator.init_state()
|
||||
initial_samples_cumulative_sum = tf.nest.map_structure(
|
||||
lambda spec: tf.zeros(spec.shape), self._record_specs)
|
||||
|
@ -100,10 +101,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return initial_state
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return global_state.clip_value
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Returns the clipped record using `clip_fn` and params.
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||
|
||||
Args:
|
||||
params: `clip_value` for the record.
|
||||
|
@ -118,14 +120,16 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return tf.nest.pack_sequence_as(record, clipped_as_list)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Updates tree, state, and returns noised cumulative sum and updated state.
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||
|
||||
Computes new cumulative sum, and returns its noised value. Grows tree_state
|
||||
Updates tree state, and returns noised cumulative sum and updated state.
|
||||
|
||||
Computes new cumulative sum, and returns its noised value. Grows tree state
|
||||
by one new leaf, and returns the new state.
|
||||
|
||||
Args:
|
||||
sample_state: Sum of clipped records for this round.
|
||||
global_state: Global state with current samples cumulative sum and tree
|
||||
global_state: Global state with current sample's cumulative sum and tree
|
||||
state.
|
||||
|
||||
Returns:
|
||||
|
@ -157,7 +161,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
`clip_norm`.
|
||||
noise_multiplier: The effective noise multiplier for the sum of records.
|
||||
Noise standard deviation is `clip_norm*noise_multiplier`.
|
||||
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
|
||||
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
|
||||
and shapes of records.
|
||||
noise_seed: Integer seed for the Gaussian noise generator. If `None`, a
|
||||
nondeterministic seed based on system time will be generated.
|
||||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
|
@ -204,9 +209,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
Attributes:
|
||||
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
|
||||
arguments: a flat list of vars in a record and a `clip_value` to clip the
|
||||
corresponding record, e.g. clip_fn(flat_record, clip_value).
|
||||
corresponding record, e.g. clip_fn(flat_record, clip_value).
|
||||
clip_value: float indicating the value at which to clip the record.
|
||||
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
|
||||
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
|
||||
and shapes of records.
|
||||
tree_aggregator: `tree_aggregation.TreeAggregator` initialized with user
|
||||
defined `noise_generator`. `noise_generator` is a
|
||||
`tree_aggregation.ValueGenerator` to generate the noise value for a tree
|
||||
|
@ -242,7 +248,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
`TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
|
||||
|
||||
Args:
|
||||
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
|
||||
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
|
||||
and shapes of records.
|
||||
noise_generator: `tree_aggregation.ValueGenerator` to generate the noise
|
||||
value for a tree node. Should be coupled with clipping norm to guarantee
|
||||
privacy.
|
||||
|
@ -263,7 +270,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns initial global state."""
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
initial_tree_state = self._tree_aggregator.init_state()
|
||||
initial_noise = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
|
||||
self._record_specs)
|
||||
|
@ -273,10 +280,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
previous_tree_noise=initial_noise)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return global_state.clip_value
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Returns the clipped record using `clip_fn` and params.
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||
|
||||
Args:
|
||||
params: `clip_value` for the record.
|
||||
|
@ -291,7 +299,9 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return tf.nest.pack_sequence_as(record, clipped_as_list)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Updates tree state, and returns residual of noised cumulative sum.
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||
|
||||
Updates tree state, and returns residual of noised cumulative sum.
|
||||
|
||||
Args:
|
||||
sample_state: Sum of clipped records for this round.
|
||||
|
@ -324,7 +334,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
`clip_norm`.
|
||||
noise_multiplier: The effective noise multiplier for the sum of records.
|
||||
Noise standard deviation is `clip_norm*noise_multiplier`.
|
||||
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
|
||||
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
|
||||
and shapes of records.
|
||||
noise_seed: Integer seed for the Gaussian noise generator. If `None`, a
|
||||
nondeterministic seed based on system time will be generated.
|
||||
use_efficient: Boolean indicating the usage of the efficient tree
|
||||
|
|
Loading…
Reference in a new issue