Add DpEvent to return value of get_noised_result. For most DPQueries, the default UnsupportedDpEvent is returned, pending further development.

PiperOrigin-RevId: 394137614
This commit is contained in:
Galen Andrew 2021-08-31 19:26:51 -07:00 committed by A. Unique TensorFlower
parent 6ac4bc8d01
commit 7e7736ea91
18 changed files with 86 additions and 94 deletions

View file

@ -16,6 +16,7 @@
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query
@ -81,4 +82,6 @@ class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
# Ensure shape as TF shape inference may fail due to custom noise sampler.
return tf.ensure_shape(noised_v, v.shape)
return tf.nest.map_structure(add_noise, sample_state), global_state
result = tf.nest.map_structure(add_noise, sample_state)
event = dp_event.UnsupportedDpEvent()
return result, global_state, event

View file

@ -16,6 +16,7 @@
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query
@ -106,4 +107,5 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
# Note that by directly returning the aggregate, this assumes that there
# will not be missing local noise shares during execution.
return sample_state, global_state
event = dp_event.UnsupportedDpEvent()
return sample_state, global_state, event

View file

@ -246,11 +246,14 @@ class DPQuery(object):
global_state: The global state, storing long-term privacy bookkeeping.
Returns:
A tuple (result, new_global_state) where "result" is the result of the
query and "new_global_state" is the updated global state. In standard
DP-SGD training, the result is a gradient update comprising a noised
average of the clipped gradients in the sample state---with the noise and
averaging performed in a manner that guarantees differential privacy.
A tuple `(result, new_global_state, event)` where:
* `result` is the result of the query,
* `new_global_state` is the updated global state, and
* `event` is the `DpEvent` that occurred.
In standard DP-SGD training, the result is a gradient update comprising a
noised average of the clipped gradients in the sample state---with the
noise and averaging performed in a manner that guarantees differential
privacy.
"""
pass
@ -297,7 +300,3 @@ class SumAggregationDPQuery(DPQuery):
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
return sample_state, global_state

View file

@ -22,6 +22,7 @@ import distutils
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
@ -45,7 +46,6 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
"""
self._l2_norm_clip = l2_norm_clip
self._stddev = stddev
self._ledger = None
def make_global_state(self, l2_norm_clip, stddev):
"""Creates a global state from the given parameters."""
@ -96,12 +96,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
if self._ledger:
dependencies = [
self._ledger.record_sum_query(global_state.l2_norm_clip,
global_state.stddev)
]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return tf.nest.map_structure(add_noise, sample_state), global_state
result = tf.nest.map_structure(add_noise, sample_state)
noise_multiplier = global_state.stddev / global_state.l2_norm_clip
event = dp_event.GaussianDpEvent(noise_multiplier)
return result, global_state, event

View file

@ -20,6 +20,8 @@ from __future__ import print_function
import collections
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
import tree
@ -96,13 +98,15 @@ class NestedQuery(dp_query.DPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
estimates_and_new_global_states = self._map_to_queries(
'get_noised_result', sample_state, global_state)
mapped_query_results = self._map_to_queries('get_noised_result',
sample_state, global_state)
flat_estimates, flat_new_global_states, flat_events = zip(
*tree.flatten_up_to(self._queries, mapped_query_results))
flat_estimates, flat_new_global_states = zip(
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
tf.nest.pack_sequence_as(self._queries, flat_new_global_states),
dp_event.ComposedDpEvent(events=flat_events))
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -19,6 +19,7 @@ from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
@ -28,19 +29,9 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
Accumulates vectors without clipping or adding noise.
"""
def __init__(self):
self._ledger = None
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
if self._ledger:
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return sample_state, global_state
return sample_state, global_state, dp_event.NonPrivateDpEvent()
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
@ -56,10 +47,6 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
privatized.
"""
def __init__(self):
"""Initializes the NoPrivacyAverageQuery."""
self._ledger = None
def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return (super(NoPrivacyAverageQuery,
@ -103,11 +90,5 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
sum_state, denominator = sample_state
if self._ledger:
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return (tf.nest.map_structure(lambda t: t / denominator,
sum_state), global_state)
result = tf.nest.map_structure(lambda t: t / denominator, sum_state)
return result, global_state, dp_event.NonPrivateDpEvent()

View file

@ -74,14 +74,16 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
noised_sum, new_sum_global_state, event = self._numerator.get_noised_result(
sample_state, global_state.numerator_state)
def normalize(v):
return tf.truediv(v, global_state.denominator)
# The denominator is constant so the privacy cost comes from the numerator.
return (tf.nest.map_structure(normalize, noised_sum),
self._GlobalState(new_sum_global_state, global_state.denominator))
self._GlobalState(new_sum_global_state,
global_state.denominator), event)
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -21,6 +21,7 @@ import collections
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
@ -123,11 +124,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_vectors, sum_state = self._sum_query.get_noised_result(
noised_vectors, sum_state, sum_event = self._sum_query.get_noised_result(
sample_state.sum_state, global_state.sum_state)
del sum_state # To be set explicitly later when we know the new clip.
new_l2_norm_clip, new_quantile_estimator_state = (
new_l2_norm_clip, new_quantile_estimator_state, quantile_event = (
self._quantile_estimator_query.get_noised_result(
sample_state.quantile_estimator_state,
global_state.quantile_estimator_state))
@ -141,7 +142,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
new_sum_query_state,
new_quantile_estimator_state)
return noised_vectors, new_global_state
event = dp_event.ComposedDpEvent(events=[sum_event, quantile_event])
return noised_vectors, new_global_state, event
def derive_metrics(self, global_state):
"""Returns the current clipping norm as a metric."""

View file

@ -135,7 +135,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
below_estimate_result, new_below_estimate_state = (
below_estimate_result, new_below_estimate_state, below_estimate_event = (
self._below_estimate_query.get_noised_result(
sample_state, global_state.below_estimate_state))
@ -159,7 +159,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
current_estimate=new_estimate,
below_estimate_state=new_below_estimate_state)
return new_estimate, new_global_state
return new_estimate, new_global_state, below_estimate_event
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -134,14 +134,14 @@ class RestartQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_results, inner_query_state = self._inner_query.get_noised_result(
noised_results, inner_state, event = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
restart_flag, indicator_state = self._restart_indicator.next(
global_state.indicator_state)
if restart_flag:
inner_query_state = self._inner_query.reset_state(noised_results,
inner_query_state)
return noised_results, self._GlobalState(inner_query_state, indicator_state)
inner_state = self._inner_query.reset_state(noised_results, inner_state)
return (noised_results, self._GlobalState(inner_state,
indicator_state), event)
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -75,7 +75,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
# Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can
@ -110,7 +110,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
# Expected value is the signal of the current round plus the residual of
# two continous tree aggregation values. The tree aggregation value can

View file

@ -44,6 +44,7 @@ def run_query(query, records, global_state=None, weights=None):
sample_state = query.accumulate_record(params, sample_state, record)
else:
for weight, record in zip(weights, records):
sample_state = query.accumulate_record(
params, sample_state, record, weight)
return query.get_noised_result(sample_state, global_state)
sample_state = query.accumulate_record(params, sample_state, record,
weight)
result, global_state, _ = query.get_noised_result(sample_state, global_state)
return result, global_state

View file

@ -24,12 +24,12 @@ import math
import attr
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
@ -57,7 +57,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
for j,sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_cumsum is privatized estimate of s_i
noised_cumsum, global_state = query.get_noised_result(
noised_cumsum, global_state, event = query.get_noised_result(
sample_state, global_state)
Attributes:
@ -176,7 +176,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
global_state,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state)
return noised_cumulative_sum, new_global_state
event = dp_event.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event
def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree.
@ -281,7 +282,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_sum is privatized estimate of x_i by conceptually postprocessing
# noised cumulative sum s_i
noised_sum, global_state = query.get_noised_result(
noised_sum, global_state, event = query.get_noised_result(
sample_state, global_state)
Attributes:
@ -398,7 +399,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
global_state.previous_tree_noise)
new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
return noised_sample, new_global_state
event = dp_event.UnsupportedDpEvent()
return noised_sample, new_global_state, event
def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree.
@ -636,7 +638,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
sample_state, inner_query_state = self._inner_query.get_noised_result(
sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity, inner_query_state=inner_query_state)
@ -647,7 +649,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, new_global_state
event = dp_event.UnsupportedDpEvent()
return tree, new_global_state, event
@classmethod
def build_central_gaussian_query(cls,

View file

@ -258,7 +258,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
for scalar, expected_sum in zip(streaming_scalars, partial_sum):
sample_state = query.initial_sample_state(scalar)
sample_state = query.accumulate_record(params, sample_state, scalar)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
self.assertEqual(query_result, expected_sum)
@ -282,7 +282,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
# For each streaming step i , the expected value is roughly
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
@ -314,7 +314,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
@ -456,7 +456,7 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
@ -609,7 +609,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@ -621,7 +621,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
sample_state, global_state = query.get_noised_result(
sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(

View file

@ -20,6 +20,7 @@ import math
import attr
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
@ -189,7 +190,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
sample_state, inner_query_state = self._inner_query.get_noised_result(
sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity, inner_query_state=inner_query_state)
@ -200,7 +201,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, new_global_state
event = dp_event.UnsupportedDpEvent()
return tree, new_global_state, event
@classmethod
def build_central_gaussian_query(cls,

View file

@ -159,7 +159,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result(
sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@ -171,7 +171,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
sample_state, global_state = query.get_noised_result(
sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(

View file

@ -164,7 +164,7 @@ def make_optimizer_class(cls):
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = (
grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
@ -235,7 +235,7 @@ def make_optimizer_class(cls):
_, sample_state = tf.while_loop(
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
grad_sums, self._global_state = (
grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
@ -363,10 +363,6 @@ def make_gaussian_optimizer_class(cls):
return config
@property
def ledger(self):
return self._dp_sum_query.ledger
return DPGaussianOptimizerClass

View file

@ -81,9 +81,10 @@ def make_keras_optimizer_class(cls):
model.fit(...)
```
""".format(base_class='tf.keras.optimizers.' + cls.__name__,
short_base_class=cls.__name__,
dp_keras_class='DPKeras' + cls.__name__)
""".format(
base_class='tf.keras.optimizers.' + cls.__name__,
short_base_class=cls.__name__,
dp_keras_class='DPKeras' + cls.__name__)
# The class tf.keras.optimizers.Optimizer has two methods to compute
# gradients, `_compute_gradients` and `get_gradients`. The first works
@ -106,8 +107,8 @@ def make_keras_optimizer_class(cls):
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch
is split.
num_microbatches: Number of microbatches into which each minibatch is
split.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
@ -210,7 +211,7 @@ def make_keras_optimizer_class(cls):
sample_state = self._dp_sum_query.initial_sample_state(params)
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = (
grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))