Add DpEvent to return value of get_noised_result. For most DPQueries, the default UnsupportedDpEvent is returned, pending further development.

PiperOrigin-RevId: 394137614
This commit is contained in:
Galen Andrew 2021-08-31 19:26:51 -07:00 committed by A. Unique TensorFlower
parent 6ac4bc8d01
commit 7e7736ea91
18 changed files with 86 additions and 94 deletions

View file

@ -16,6 +16,7 @@
import collections import collections
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
@ -81,4 +82,6 @@ class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
# Ensure shape as TF shape inference may fail due to custom noise sampler. # Ensure shape as TF shape inference may fail due to custom noise sampler.
return tf.ensure_shape(noised_v, v.shape) return tf.ensure_shape(noised_v, v.shape)
return tf.nest.map_structure(add_noise, sample_state), global_state result = tf.nest.map_structure(add_noise, sample_state)
event = dp_event.UnsupportedDpEvent()
return result, global_state, event

View file

@ -16,6 +16,7 @@
import collections import collections
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
@ -106,4 +107,5 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
# Note that by directly returning the aggregate, this assumes that there # Note that by directly returning the aggregate, this assumes that there
# will not be missing local noise shares during execution. # will not be missing local noise shares during execution.
return sample_state, global_state event = dp_event.UnsupportedDpEvent()
return sample_state, global_state, event

View file

@ -246,11 +246,14 @@ class DPQuery(object):
global_state: The global state, storing long-term privacy bookkeeping. global_state: The global state, storing long-term privacy bookkeeping.
Returns: Returns:
A tuple (result, new_global_state) where "result" is the result of the A tuple `(result, new_global_state, event)` where:
query and "new_global_state" is the updated global state. In standard * `result` is the result of the query,
DP-SGD training, the result is a gradient update comprising a noised * `new_global_state` is the updated global state, and
average of the clipped gradients in the sample state---with the noise and * `event` is the `DpEvent` that occurred.
averaging performed in a manner that guarantees differential privacy. In standard DP-SGD training, the result is a gradient update comprising a
noised average of the clipped gradients in the sample state---with the
noise and averaging performed in a manner that guarantees differential
privacy.
""" """
pass pass
@ -297,7 +300,3 @@ class SumAggregationDPQuery(DPQuery):
def merge_sample_states(self, sample_state_1, sample_state_2): def merge_sample_states(self, sample_state_1, sample_state_2):
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`.""" """Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2) return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
return sample_state, global_state

View file

@ -22,6 +22,7 @@ import distutils
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
@ -45,7 +46,6 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
""" """
self._l2_norm_clip = l2_norm_clip self._l2_norm_clip = l2_norm_clip
self._stddev = stddev self._stddev = stddev
self._ledger = None
def make_global_state(self, l2_norm_clip, stddev): def make_global_state(self, l2_norm_clip, stddev):
"""Creates a global state from the given parameters.""" """Creates a global state from the given parameters."""
@ -96,12 +96,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def add_noise(v): def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
if self._ledger: result = tf.nest.map_structure(add_noise, sample_state)
dependencies = [ noise_multiplier = global_state.stddev / global_state.l2_norm_clip
self._ledger.record_sum_query(global_state.l2_norm_clip, event = dp_event.GaussianDpEvent(noise_multiplier)
global_state.stddev)
] return result, global_state, event
else:
dependencies = []
with tf.control_dependencies(dependencies):
return tf.nest.map_structure(add_noise, sample_state), global_state

View file

@ -20,6 +20,8 @@ from __future__ import print_function
import collections import collections
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
import tree import tree
@ -96,13 +98,15 @@ class NestedQuery(dp_query.DPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
estimates_and_new_global_states = self._map_to_queries( mapped_query_results = self._map_to_queries('get_noised_result',
'get_noised_result', sample_state, global_state) sample_state, global_state)
flat_estimates, flat_new_global_states, flat_events = zip(
*tree.flatten_up_to(self._queries, mapped_query_results))
flat_estimates, flat_new_global_states = zip(
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
return (tf.nest.pack_sequence_as(self._queries, flat_estimates), return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
tf.nest.pack_sequence_as(self._queries, flat_new_global_states)) tf.nest.pack_sequence_as(self._queries, flat_new_global_states),
dp_event.ComposedDpEvent(events=flat_events))
def derive_metrics(self, global_state): def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`.""" """Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -19,6 +19,7 @@ from __future__ import print_function
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
@ -28,19 +29,9 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
Accumulates vectors without clipping or adding noise. Accumulates vectors without clipping or adding noise.
""" """
def __init__(self):
self._ledger = None
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
return sample_state, global_state, dp_event.NonPrivateDpEvent()
if self._ledger:
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return sample_state, global_state
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery): class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
@ -56,10 +47,6 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
privatized. privatized.
""" """
def __init__(self):
"""Initializes the NoPrivacyAverageQuery."""
self._ledger = None
def initial_sample_state(self, template): def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" """Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return (super(NoPrivacyAverageQuery, return (super(NoPrivacyAverageQuery,
@ -103,11 +90,5 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
sum_state, denominator = sample_state sum_state, denominator = sample_state
if self._ledger: result = tf.nest.map_structure(lambda t: t / denominator, sum_state)
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)] return result, global_state, dp_event.NonPrivateDpEvent()
else:
dependencies = []
with tf.control_dependencies(dependencies):
return (tf.nest.map_structure(lambda t: t / denominator,
sum_state), global_state)

View file

@ -74,14 +74,16 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_sum, new_sum_global_state = self._numerator.get_noised_result( noised_sum, new_sum_global_state, event = self._numerator.get_noised_result(
sample_state, global_state.numerator_state) sample_state, global_state.numerator_state)
def normalize(v): def normalize(v):
return tf.truediv(v, global_state.denominator) return tf.truediv(v, global_state.denominator)
# The denominator is constant so the privacy cost comes from the numerator.
return (tf.nest.map_structure(normalize, noised_sum), return (tf.nest.map_structure(normalize, noised_sum),
self._GlobalState(new_sum_global_state, global_state.denominator)) self._GlobalState(new_sum_global_state,
global_state.denominator), event)
def derive_metrics(self, global_state): def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`.""" """Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -21,6 +21,7 @@ import collections
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
@ -123,11 +124,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_vectors, sum_state = self._sum_query.get_noised_result( noised_vectors, sum_state, sum_event = self._sum_query.get_noised_result(
sample_state.sum_state, global_state.sum_state) sample_state.sum_state, global_state.sum_state)
del sum_state # To be set explicitly later when we know the new clip. del sum_state # To be set explicitly later when we know the new clip.
new_l2_norm_clip, new_quantile_estimator_state = ( new_l2_norm_clip, new_quantile_estimator_state, quantile_event = (
self._quantile_estimator_query.get_noised_result( self._quantile_estimator_query.get_noised_result(
sample_state.quantile_estimator_state, sample_state.quantile_estimator_state,
global_state.quantile_estimator_state)) global_state.quantile_estimator_state))
@ -141,7 +142,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
new_sum_query_state, new_sum_query_state,
new_quantile_estimator_state) new_quantile_estimator_state)
return noised_vectors, new_global_state event = dp_event.ComposedDpEvent(events=[sum_event, quantile_event])
return noised_vectors, new_global_state, event
def derive_metrics(self, global_state): def derive_metrics(self, global_state):
"""Returns the current clipping norm as a metric.""" """Returns the current clipping norm as a metric."""

View file

@ -135,7 +135,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
below_estimate_result, new_below_estimate_state = ( below_estimate_result, new_below_estimate_state, below_estimate_event = (
self._below_estimate_query.get_noised_result( self._below_estimate_query.get_noised_result(
sample_state, global_state.below_estimate_state)) sample_state, global_state.below_estimate_state))
@ -159,7 +159,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
current_estimate=new_estimate, current_estimate=new_estimate,
below_estimate_state=new_below_estimate_state) below_estimate_state=new_below_estimate_state)
return new_estimate, new_global_state return new_estimate, new_global_state, below_estimate_event
def derive_metrics(self, global_state): def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`.""" """Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -134,14 +134,14 @@ class RestartQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_results, inner_query_state = self._inner_query.get_noised_result( noised_results, inner_state, event = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state) sample_state, global_state.inner_query_state)
restart_flag, indicator_state = self._restart_indicator.next( restart_flag, indicator_state = self._restart_indicator.next(
global_state.indicator_state) global_state.indicator_state)
if restart_flag: if restart_flag:
inner_query_state = self._inner_query.reset_state(noised_results, inner_state = self._inner_query.reset_state(noised_results, inner_state)
inner_query_state) return (noised_results, self._GlobalState(inner_state,
return noised_results, self._GlobalState(inner_query_state, indicator_state) indicator_state), event)
def derive_metrics(self, global_state): def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`.""" """Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -75,7 +75,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps): for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value) sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result( query_result, global_state, _ = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
# Expected value is the combination of cumsum of signal; sum of trees # Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can # that have been reset; current tree sum. The tree aggregation value can
@ -110,7 +110,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps): for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value) sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result( query_result, global_state, _ = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
# Expected value is the signal of the current round plus the residual of # Expected value is the signal of the current round plus the residual of
# two continuous tree aggregation values. The tree aggregation value can # two continuous tree aggregation values. The tree aggregation value can

View file

@ -44,6 +44,7 @@ def run_query(query, records, global_state=None, weights=None):
sample_state = query.accumulate_record(params, sample_state, record) sample_state = query.accumulate_record(params, sample_state, record)
else: else:
for weight, record in zip(weights, records): for weight, record in zip(weights, records):
sample_state = query.accumulate_record( sample_state = query.accumulate_record(params, sample_state, record,
params, sample_state, record, weight) weight)
return query.get_noised_result(sample_state, global_state) result, global_state, _ = query.get_noised_result(sample_state, global_state)
return result, global_state

View file

@ -24,12 +24,12 @@ import math
import attr import attr
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be # TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
# in the same module. # in the same module.
@ -57,7 +57,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
for j,sample in enumerate(samples): for j,sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample) sample_state = query.accumulate_record(params, sample_state, sample)
# noised_cumsum is privatized estimate of s_i # noised_cumsum is privatized estimate of s_i
noised_cumsum, global_state = query.get_noised_result( noised_cumsum, global_state, event = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
Attributes: Attributes:
@ -176,7 +176,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
global_state, global_state,
samples_cumulative_sum=new_cumulative_sum, samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state) tree_state=new_tree_state)
return noised_cumulative_sum, new_global_state event = dp_event.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event
def reset_state(self, noised_results, global_state): def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree. """Returns state after resetting the tree.
@ -281,7 +282,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
sample_state = query.accumulate_record(params, sample_state, sample) sample_state = query.accumulate_record(params, sample_state, sample)
# noised_sum is privatized estimate of x_i by conceptually postprocessing # noised_sum is privatized estimate of x_i by conceptually postprocessing
# noised cumulative sum s_i # noised cumulative sum s_i
noised_sum, global_state = query.get_noised_result( noised_sum, global_state, event = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
Attributes: Attributes:
@ -398,7 +399,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
global_state.previous_tree_noise) global_state.previous_tree_noise)
new_global_state = attr.evolve( new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state) global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
return noised_sample, new_global_state event = dp_event.UnsupportedDpEvent()
return noised_sample, new_global_state, event
def reset_state(self, noised_results, global_state): def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree. """Returns state after resetting the tree.
@ -636,7 +638,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
# This part is not written in tensorflow and will be executed on the server # This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with # side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning. # tff.aggregators.DifferentiallyPrivateFactory for federated learning.
sample_state, inner_query_state = self._inner_query.get_noised_result( sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state) sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState( new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity, inner_query_state=inner_query_state) arity=global_state.arity, inner_query_state=inner_query_state)
@ -647,7 +649,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
] ]
tree = tf.RaggedTensor.from_row_splits( tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits) values=sample_state, row_splits=row_splits)
return tree, new_global_state event = dp_event.UnsupportedDpEvent()
return tree, new_global_state, event
@classmethod @classmethod
def build_central_gaussian_query(cls, def build_central_gaussian_query(cls,

View file

@ -258,7 +258,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
for scalar, expected_sum in zip(streaming_scalars, partial_sum): for scalar, expected_sum in zip(streaming_scalars, partial_sum):
sample_state = query.initial_sample_state(scalar) sample_state = query.initial_sample_state(scalar)
sample_state = query.accumulate_record(params, sample_state, scalar) sample_state = query.accumulate_record(params, sample_state, scalar)
query_result, global_state = query.get_noised_result( query_result, global_state, _ = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
self.assertEqual(query_result, expected_sum) self.assertEqual(query_result, expected_sum)
@ -282,7 +282,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps): for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value) sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result( query_result, global_state, _ = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
# For each streaming step i , the expected value is roughly # For each streaming step i , the expected value is roughly
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`. # `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
@ -314,7 +314,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps): for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value) sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result( query_result, global_state, _ = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
if i % frequency == frequency - 1: if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state) global_state = query.reset_state(query_result, global_state)
@ -456,7 +456,7 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
for i in range(total_steps): for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value) sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value) sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result( query_result, global_state, _ = query.get_noised_result(
sample_state, global_state) sample_state, global_state)
if i % frequency == frequency - 1: if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state) global_state = query.reset_state(query_result, global_state)
@ -609,7 +609,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record) preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result( sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state) preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree) self.assertAllClose(sample_state, expected_tree)
@ -621,7 +621,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.])) preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
sample_state, global_state = query.get_noised_result( sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state) preprocessed_record, global_state)
self.assertAllClose( self.assertAllClose(

View file

@ -20,6 +20,7 @@ import math
import attr import attr
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
@ -189,7 +190,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
# This part is not written in tensorflow and will be executed on the server # This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with # side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning. # tff.aggregators.DifferentiallyPrivateFactory for federated learning.
sample_state, inner_query_state = self._inner_query.get_noised_result( sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state) sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState( new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity, inner_query_state=inner_query_state) arity=global_state.arity, inner_query_state=inner_query_state)
@ -200,7 +201,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
] ]
tree = tf.RaggedTensor.from_row_splits( tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits) values=sample_state, row_splits=row_splits)
return tree, new_global_state event = dp_event.UnsupportedDpEvent()
return tree, new_global_state, event
@classmethod @classmethod
def build_central_gaussian_query(cls, def build_central_gaussian_query(cls,

View file

@ -159,7 +159,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record) preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state = query.get_noised_result( sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state) preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree) self.assertAllClose(sample_state, expected_tree)
@ -171,7 +171,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.])) preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
sample_state, global_state = query.get_noised_result( sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state) preprocessed_record, global_state)
self.assertAllClose( self.assertAllClose(

View file

@ -164,7 +164,7 @@ def make_optimizer_class(cls):
for idx in range(self._num_microbatches): for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state) sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = ( grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state, self._dp_sum_query.get_noised_result(sample_state,
self._global_state)) self._global_state))
@ -235,7 +235,7 @@ def make_optimizer_class(cls):
_, sample_state = tf.while_loop( _, sample_state = tf.while_loop(
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state]) cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
grad_sums, self._global_state = ( grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state, self._dp_sum_query.get_noised_result(sample_state,
self._global_state)) self._global_state))
@ -363,10 +363,6 @@ def make_gaussian_optimizer_class(cls):
return config return config
@property
def ledger(self):
return self._dp_sum_query.ledger
return DPGaussianOptimizerClass return DPGaussianOptimizerClass

View file

@ -81,9 +81,10 @@ def make_keras_optimizer_class(cls):
model.fit(...) model.fit(...)
``` ```
""".format(base_class='tf.keras.optimizers.' + cls.__name__, """.format(
short_base_class=cls.__name__, base_class='tf.keras.optimizers.' + cls.__name__,
dp_keras_class='DPKeras' + cls.__name__) short_base_class=cls.__name__,
dp_keras_class='DPKeras' + cls.__name__)
# The class tf.keras.optimizers.Optimizer has two methods to compute # The class tf.keras.optimizers.Optimizer has two methods to compute
# gradients, `_compute_gradients` and `get_gradients`. The first works # gradients, `_compute_gradients` and `get_gradients`. The first works
@ -106,8 +107,8 @@ def make_keras_optimizer_class(cls):
Args: Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients). l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm. noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch num_microbatches: Number of microbatches into which each minibatch is
is split. split.
*args: These will be passed on to the base class `__init__` method. *args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method. **kwargs: These will be passed on to the base class `__init__` method.
""" """
@ -210,7 +211,7 @@ def make_keras_optimizer_class(cls):
sample_state = self._dp_sum_query.initial_sample_state(params) sample_state = self._dp_sum_query.initial_sample_state(params)
for idx in range(self._num_microbatches): for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state) sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = ( grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state, self._dp_sum_query.get_noised_result(sample_state,
self._global_state)) self._global_state))