Add DpEvent to return value of get_noised_result. For most DPQueries, the default UnsupportedDpEvent is returned, pending further development.
PiperOrigin-RevId: 394137614
This commit is contained in:
parent
6ac4bc8d01
commit
7e7736ea91
18 changed files with 86 additions and 94 deletions
|
@ -16,6 +16,7 @@
|
|||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
|
@ -81,4 +82,6 @@ class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
# Ensure shape as TF shape inference may fail due to custom noise sampler.
|
||||
return tf.ensure_shape(noised_v, v.shape)
|
||||
|
||||
return tf.nest.map_structure(add_noise, sample_state), global_state
|
||||
result = tf.nest.map_structure(add_noise, sample_state)
|
||||
event = dp_event.UnsupportedDpEvent()
|
||||
return result, global_state, event
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
|
@ -106,4 +107,5 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
def get_noised_result(self, sample_state, global_state):
|
||||
# Note that by directly returning the aggregate, this assumes that there
|
||||
# will not be missing local noise shares during execution.
|
||||
return sample_state, global_state
|
||||
event = dp_event.UnsupportedDpEvent()
|
||||
return sample_state, global_state, event
|
||||
|
|
|
@ -246,11 +246,14 @@ class DPQuery(object):
|
|||
global_state: The global state, storing long-term privacy bookkeeping.
|
||||
|
||||
Returns:
|
||||
A tuple (result, new_global_state) where "result" is the result of the
|
||||
query and "new_global_state" is the updated global state. In standard
|
||||
DP-SGD training, the result is a gradient update comprising a noised
|
||||
average of the clipped gradients in the sample state---with the noise and
|
||||
averaging performed in a manner that guarantees differential privacy.
|
||||
A tuple `(result, new_global_state, event)` where:
|
||||
* `result` is the result of the query,
|
||||
* `new_global_state` is the updated global state, and
|
||||
* `event` is the `DpEvent` that occurred.
|
||||
In standard DP-SGD training, the result is a gradient update comprising a
|
||||
noised average of the clipped gradients in the sample state---with the
|
||||
noise and averaging performed in a manner that guarantees differential
|
||||
privacy.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -297,7 +300,3 @@ class SumAggregationDPQuery(DPQuery):
|
|||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
|
||||
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
return sample_state, global_state
|
||||
|
|
|
@ -22,6 +22,7 @@ import distutils
|
|||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
|
||||
|
@ -45,7 +46,6 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
"""
|
||||
self._l2_norm_clip = l2_norm_clip
|
||||
self._stddev = stddev
|
||||
self._ledger = None
|
||||
|
||||
def make_global_state(self, l2_norm_clip, stddev):
|
||||
"""Creates a global state from the given parameters."""
|
||||
|
@ -96,12 +96,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
def add_noise(v):
|
||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||
|
||||
if self._ledger:
|
||||
dependencies = [
|
||||
self._ledger.record_sum_query(global_state.l2_norm_clip,
|
||||
global_state.stddev)
|
||||
]
|
||||
else:
|
||||
dependencies = []
|
||||
with tf.control_dependencies(dependencies):
|
||||
return tf.nest.map_structure(add_noise, sample_state), global_state
|
||||
result = tf.nest.map_structure(add_noise, sample_state)
|
||||
noise_multiplier = global_state.stddev / global_state.l2_norm_clip
|
||||
event = dp_event.GaussianDpEvent(noise_multiplier)
|
||||
|
||||
return result, global_state, event
|
||||
|
|
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
|||
import collections
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
import tree
|
||||
|
||||
|
@ -96,13 +98,15 @@ class NestedQuery(dp_query.DPQuery):
|
|||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
estimates_and_new_global_states = self._map_to_queries(
|
||||
'get_noised_result', sample_state, global_state)
|
||||
mapped_query_results = self._map_to_queries('get_noised_result',
|
||||
sample_state, global_state)
|
||||
|
||||
flat_estimates, flat_new_global_states, flat_events = zip(
|
||||
*tree.flatten_up_to(self._queries, mapped_query_results))
|
||||
|
||||
flat_estimates, flat_new_global_states = zip(
|
||||
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
|
||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states),
|
||||
dp_event.ComposedDpEvent(events=flat_events))
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
|
||||
|
@ -28,19 +29,9 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
|
|||
Accumulates vectors without clipping or adding noise.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._ledger = None
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
|
||||
if self._ledger:
|
||||
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
|
||||
else:
|
||||
dependencies = []
|
||||
|
||||
with tf.control_dependencies(dependencies):
|
||||
return sample_state, global_state
|
||||
return sample_state, global_state, dp_event.NonPrivateDpEvent()
|
||||
|
||||
|
||||
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||
|
@ -56,10 +47,6 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
|||
privatized.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the NoPrivacyAverageQuery."""
|
||||
self._ledger = None
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
return (super(NoPrivacyAverageQuery,
|
||||
|
@ -103,11 +90,5 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
|||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
sum_state, denominator = sample_state
|
||||
|
||||
if self._ledger:
|
||||
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
|
||||
else:
|
||||
dependencies = []
|
||||
|
||||
with tf.control_dependencies(dependencies):
|
||||
return (tf.nest.map_structure(lambda t: t / denominator,
|
||||
sum_state), global_state)
|
||||
result = tf.nest.map_structure(lambda t: t / denominator, sum_state)
|
||||
return result, global_state, dp_event.NonPrivateDpEvent()
|
||||
|
|
|
@ -74,14 +74,16 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||
noised_sum, new_sum_global_state, event = self._numerator.get_noised_result(
|
||||
sample_state, global_state.numerator_state)
|
||||
|
||||
def normalize(v):
|
||||
return tf.truediv(v, global_state.denominator)
|
||||
|
||||
# The denominator is constant so the privacy cost comes from the numerator.
|
||||
return (tf.nest.map_structure(normalize, noised_sum),
|
||||
self._GlobalState(new_sum_global_state, global_state.denominator))
|
||||
self._GlobalState(new_sum_global_state,
|
||||
global_state.denominator), event)
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
|
|
|
@ -21,6 +21,7 @@ import collections
|
|||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
|
||||
|
@ -123,11 +124,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
||||
noised_vectors, sum_state, sum_event = self._sum_query.get_noised_result(
|
||||
sample_state.sum_state, global_state.sum_state)
|
||||
del sum_state # To be set explicitly later when we know the new clip.
|
||||
|
||||
new_l2_norm_clip, new_quantile_estimator_state = (
|
||||
new_l2_norm_clip, new_quantile_estimator_state, quantile_event = (
|
||||
self._quantile_estimator_query.get_noised_result(
|
||||
sample_state.quantile_estimator_state,
|
||||
global_state.quantile_estimator_state))
|
||||
|
@ -141,7 +142,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
new_sum_query_state,
|
||||
new_quantile_estimator_state)
|
||||
|
||||
return noised_vectors, new_global_state
|
||||
event = dp_event.ComposedDpEvent(events=[sum_event, quantile_event])
|
||||
return noised_vectors, new_global_state, event
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Returns the current clipping norm as a metric."""
|
||||
|
|
|
@ -135,7 +135,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
below_estimate_result, new_below_estimate_state = (
|
||||
below_estimate_result, new_below_estimate_state, below_estimate_event = (
|
||||
self._below_estimate_query.get_noised_result(
|
||||
sample_state, global_state.below_estimate_state))
|
||||
|
||||
|
@ -159,7 +159,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
current_estimate=new_estimate,
|
||||
below_estimate_state=new_below_estimate_state)
|
||||
|
||||
return new_estimate, new_global_state
|
||||
return new_estimate, new_global_state, below_estimate_event
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
|
|
|
@ -134,14 +134,14 @@ class RestartQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||
noised_results, inner_query_state = self._inner_query.get_noised_result(
|
||||
noised_results, inner_state, event = self._inner_query.get_noised_result(
|
||||
sample_state, global_state.inner_query_state)
|
||||
restart_flag, indicator_state = self._restart_indicator.next(
|
||||
global_state.indicator_state)
|
||||
if restart_flag:
|
||||
inner_query_state = self._inner_query.reset_state(noised_results,
|
||||
inner_query_state)
|
||||
return noised_results, self._GlobalState(inner_query_state, indicator_state)
|
||||
inner_state = self._inner_query.reset_state(noised_results, inner_state)
|
||||
return (noised_results, self._GlobalState(inner_state,
|
||||
indicator_state), event)
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||
|
|
|
@ -75,7 +75,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
query_result, global_state, _ = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# Expected value is the combination of cumsum of signal; sum of trees
|
||||
# that have been reset; current tree sum. The tree aggregation value can
|
||||
|
@ -110,7 +110,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
query_result, global_state, _ = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# Expected value is the signal of the current round plus the residual of
|
||||
# two continous tree aggregation values. The tree aggregation value can
|
||||
|
|
|
@ -44,6 +44,7 @@ def run_query(query, records, global_state=None, weights=None):
|
|||
sample_state = query.accumulate_record(params, sample_state, record)
|
||||
else:
|
||||
for weight, record in zip(weights, records):
|
||||
sample_state = query.accumulate_record(
|
||||
params, sample_state, record, weight)
|
||||
return query.get_noised_result(sample_state, global_state)
|
||||
sample_state = query.accumulate_record(params, sample_state, record,
|
||||
weight)
|
||||
result, global_state, _ = query.get_noised_result(sample_state, global_state)
|
||||
return result, global_state
|
||||
|
|
|
@ -24,12 +24,12 @@ import math
|
|||
|
||||
import attr
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
|
||||
|
||||
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
|
||||
# in the same module.
|
||||
|
||||
|
@ -57,7 +57,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
for j,sample in enumerate(samples):
|
||||
sample_state = query.accumulate_record(params, sample_state, sample)
|
||||
# noised_cumsum is privatized estimate of s_i
|
||||
noised_cumsum, global_state = query.get_noised_result(
|
||||
noised_cumsum, global_state, event = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
|
||||
Attributes:
|
||||
|
@ -176,7 +176,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
global_state,
|
||||
samples_cumulative_sum=new_cumulative_sum,
|
||||
tree_state=new_tree_state)
|
||||
return noised_cumulative_sum, new_global_state
|
||||
event = dp_event.UnsupportedDpEvent()
|
||||
return noised_cumulative_sum, new_global_state, event
|
||||
|
||||
def reset_state(self, noised_results, global_state):
|
||||
"""Returns state after resetting the tree.
|
||||
|
@ -281,7 +282,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
sample_state = query.accumulate_record(params, sample_state, sample)
|
||||
# noised_sum is privatized estimate of x_i by conceptually postprocessing
|
||||
# noised cumulative sum s_i
|
||||
noised_sum, global_state = query.get_noised_result(
|
||||
noised_sum, global_state, event = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
|
||||
Attributes:
|
||||
|
@ -398,7 +399,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
global_state.previous_tree_noise)
|
||||
new_global_state = attr.evolve(
|
||||
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
||||
return noised_sample, new_global_state
|
||||
event = dp_event.UnsupportedDpEvent()
|
||||
return noised_sample, new_global_state, event
|
||||
|
||||
def reset_state(self, noised_results, global_state):
|
||||
"""Returns state after resetting the tree.
|
||||
|
@ -636,7 +638,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
# This part is not written in tensorflow and will be executed on the server
|
||||
# side instead of the client side if used with
|
||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
||||
sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
|
||||
sample_state, global_state.inner_query_state)
|
||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||
arity=global_state.arity, inner_query_state=inner_query_state)
|
||||
|
@ -647,7 +649,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
]
|
||||
tree = tf.RaggedTensor.from_row_splits(
|
||||
values=sample_state, row_splits=row_splits)
|
||||
return tree, new_global_state
|
||||
event = dp_event.UnsupportedDpEvent()
|
||||
return tree, new_global_state, event
|
||||
|
||||
@classmethod
|
||||
def build_central_gaussian_query(cls,
|
||||
|
|
|
@ -258,7 +258,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for scalar, expected_sum in zip(streaming_scalars, partial_sum):
|
||||
sample_state = query.initial_sample_state(scalar)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
query_result, global_state, _ = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
self.assertEqual(query_result, expected_sum)
|
||||
|
||||
|
@ -282,7 +282,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
query_result, global_state, _ = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# For each streaming step i , the expected value is roughly
|
||||
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
|
||||
|
@ -314,7 +314,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
query_result, global_state, _ = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
if i % frequency == frequency - 1:
|
||||
global_state = query.reset_state(query_result, global_state)
|
||||
|
@ -456,7 +456,7 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for i in range(total_steps):
|
||||
sample_state = query.initial_sample_state(scalar_value)
|
||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||
query_result, global_state = query.get_noised_result(
|
||||
query_result, global_state, _ = query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
if i % frequency == frequency - 1:
|
||||
global_state = query.reset_state(query_result, global_state)
|
||||
|
@ -609,7 +609,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
sample_state, global_state, _ = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(sample_state, expected_tree)
|
||||
|
@ -621,7 +621,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
sample_state, global_state, _ = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(
|
||||
|
|
|
@ -20,6 +20,7 @@ import math
|
|||
|
||||
import attr
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
|
@ -189,7 +190,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
# This part is not written in tensorflow and will be executed on the server
|
||||
# side instead of the client side if used with
|
||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
||||
sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
|
||||
sample_state, global_state.inner_query_state)
|
||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||
arity=global_state.arity, inner_query_state=inner_query_state)
|
||||
|
@ -200,7 +201,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
]
|
||||
tree = tf.RaggedTensor.from_row_splits(
|
||||
values=sample_state, row_splits=row_splits)
|
||||
return tree, new_global_state
|
||||
event = dp_event.UnsupportedDpEvent()
|
||||
return tree, new_global_state, event
|
||||
|
||||
@classmethod
|
||||
def build_central_gaussian_query(cls,
|
||||
|
|
|
@ -159,7 +159,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
sample_state, global_state, _ = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(sample_state, expected_tree)
|
||||
|
@ -171,7 +171,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
||||
sample_state, global_state = query.get_noised_result(
|
||||
sample_state, global_state, _ = query.get_noised_result(
|
||||
preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(
|
||||
|
|
|
@ -164,7 +164,7 @@ def make_optimizer_class(cls):
|
|||
for idx in range(self._num_microbatches):
|
||||
sample_state = process_microbatch(idx, sample_state)
|
||||
|
||||
grad_sums, self._global_state = (
|
||||
grad_sums, self._global_state, _ = (
|
||||
self._dp_sum_query.get_noised_result(sample_state,
|
||||
self._global_state))
|
||||
|
||||
|
@ -235,7 +235,7 @@ def make_optimizer_class(cls):
|
|||
_, sample_state = tf.while_loop(
|
||||
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
|
||||
|
||||
grad_sums, self._global_state = (
|
||||
grad_sums, self._global_state, _ = (
|
||||
self._dp_sum_query.get_noised_result(sample_state,
|
||||
self._global_state))
|
||||
|
||||
|
@ -363,10 +363,6 @@ def make_gaussian_optimizer_class(cls):
|
|||
|
||||
return config
|
||||
|
||||
@property
|
||||
def ledger(self):
|
||||
return self._dp_sum_query.ledger
|
||||
|
||||
return DPGaussianOptimizerClass
|
||||
|
||||
|
||||
|
|
|
@ -81,9 +81,10 @@ def make_keras_optimizer_class(cls):
|
|||
model.fit(...)
|
||||
```
|
||||
|
||||
""".format(base_class='tf.keras.optimizers.' + cls.__name__,
|
||||
short_base_class=cls.__name__,
|
||||
dp_keras_class='DPKeras' + cls.__name__)
|
||||
""".format(
|
||||
base_class='tf.keras.optimizers.' + cls.__name__,
|
||||
short_base_class=cls.__name__,
|
||||
dp_keras_class='DPKeras' + cls.__name__)
|
||||
|
||||
# The class tf.keras.optimizers.Optimizer has two methods to compute
|
||||
# gradients, `_compute_gradients` and `get_gradients`. The first works
|
||||
|
@ -106,8 +107,8 @@ def make_keras_optimizer_class(cls):
|
|||
Args:
|
||||
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
|
||||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||
num_microbatches: Number of microbatches into which each minibatch
|
||||
is split.
|
||||
num_microbatches: Number of microbatches into which each minibatch is
|
||||
split.
|
||||
*args: These will be passed on to the base class `__init__` method.
|
||||
**kwargs: These will be passed on to the base class `__init__` method.
|
||||
"""
|
||||
|
@ -210,7 +211,7 @@ def make_keras_optimizer_class(cls):
|
|||
sample_state = self._dp_sum_query.initial_sample_state(params)
|
||||
for idx in range(self._num_microbatches):
|
||||
sample_state = process_microbatch(idx, sample_state)
|
||||
grad_sums, self._global_state = (
|
||||
grad_sums, self._global_state, _ = (
|
||||
self._dp_sum_query.get_noised_result(sample_state,
|
||||
self._global_state))
|
||||
|
||||
|
|
Loading…
Reference in a new issue