forked from 626_privacy/tensorflow_privacy
Add DpEvent to return value of get_noised_result. For most DPQueries, the default UnsupportedDpEvent is returned, pending further development.
PiperOrigin-RevId: 394137614
This commit is contained in:
parent
6ac4bc8d01
commit
7e7736ea91
18 changed files with 86 additions and 94 deletions
|
@ -16,6 +16,7 @@
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
|
||||||
|
@ -81,4 +82,6 @@ class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
# Ensure shape as TF shape inference may fail due to custom noise sampler.
|
# Ensure shape as TF shape inference may fail due to custom noise sampler.
|
||||||
return tf.ensure_shape(noised_v, v.shape)
|
return tf.ensure_shape(noised_v, v.shape)
|
||||||
|
|
||||||
return tf.nest.map_structure(add_noise, sample_state), global_state
|
result = tf.nest.map_structure(add_noise, sample_state)
|
||||||
|
event = dp_event.UnsupportedDpEvent()
|
||||||
|
return result, global_state, event
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
|
||||||
|
@ -106,4 +107,5 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
# Note that by directly returning the aggregate, this assumes that there
|
# Note that by directly returning the aggregate, this assumes that there
|
||||||
# will not be missing local noise shares during execution.
|
# will not be missing local noise shares during execution.
|
||||||
return sample_state, global_state
|
event = dp_event.UnsupportedDpEvent()
|
||||||
|
return sample_state, global_state, event
|
||||||
|
|
|
@ -246,11 +246,14 @@ class DPQuery(object):
|
||||||
global_state: The global state, storing long-term privacy bookkeeping.
|
global_state: The global state, storing long-term privacy bookkeeping.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple (result, new_global_state) where "result" is the result of the
|
A tuple `(result, new_global_state, event)` where:
|
||||||
query and "new_global_state" is the updated global state. In standard
|
* `result` is the result of the query,
|
||||||
DP-SGD training, the result is a gradient update comprising a noised
|
* `new_global_state` is the updated global state, and
|
||||||
average of the clipped gradients in the sample state---with the noise and
|
* `event` is the `DpEvent` that occurred.
|
||||||
averaging performed in a manner that guarantees differential privacy.
|
In standard DP-SGD training, the result is a gradient update comprising a
|
||||||
|
noised average of the clipped gradients in the sample state---with the
|
||||||
|
noise and averaging performed in a manner that guarantees differential
|
||||||
|
privacy.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -297,7 +300,3 @@ class SumAggregationDPQuery(DPQuery):
|
||||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
|
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
|
||||||
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
|
||||||
return sample_state, global_state
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import distutils
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,7 +46,6 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""
|
"""
|
||||||
self._l2_norm_clip = l2_norm_clip
|
self._l2_norm_clip = l2_norm_clip
|
||||||
self._stddev = stddev
|
self._stddev = stddev
|
||||||
self._ledger = None
|
|
||||||
|
|
||||||
def make_global_state(self, l2_norm_clip, stddev):
|
def make_global_state(self, l2_norm_clip, stddev):
|
||||||
"""Creates a global state from the given parameters."""
|
"""Creates a global state from the given parameters."""
|
||||||
|
@ -96,12 +96,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||||
|
|
||||||
if self._ledger:
|
result = tf.nest.map_structure(add_noise, sample_state)
|
||||||
dependencies = [
|
noise_multiplier = global_state.stddev / global_state.l2_norm_clip
|
||||||
self._ledger.record_sum_query(global_state.l2_norm_clip,
|
event = dp_event.GaussianDpEvent(noise_multiplier)
|
||||||
global_state.stddev)
|
|
||||||
]
|
return result, global_state, event
|
||||||
else:
|
|
||||||
dependencies = []
|
|
||||||
with tf.control_dependencies(dependencies):
|
|
||||||
return tf.nest.map_structure(add_noise, sample_state), global_state
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
import tree
|
import tree
|
||||||
|
|
||||||
|
@ -96,13 +98,15 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
estimates_and_new_global_states = self._map_to_queries(
|
mapped_query_results = self._map_to_queries('get_noised_result',
|
||||||
'get_noised_result', sample_state, global_state)
|
sample_state, global_state)
|
||||||
|
|
||||||
|
flat_estimates, flat_new_global_states, flat_events = zip(
|
||||||
|
*tree.flatten_up_to(self._queries, mapped_query_results))
|
||||||
|
|
||||||
flat_estimates, flat_new_global_states = zip(
|
|
||||||
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
|
|
||||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
tf.nest.pack_sequence_as(self._queries, flat_new_global_states),
|
||||||
|
dp_event.ComposedDpEvent(events=flat_events))
|
||||||
|
|
||||||
def derive_metrics(self, global_state):
|
def derive_metrics(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,19 +29,9 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
|
||||||
Accumulates vectors without clipping or adding noise.
|
Accumulates vectors without clipping or adding noise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._ledger = None
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
|
return sample_state, global_state, dp_event.NonPrivateDpEvent()
|
||||||
if self._ledger:
|
|
||||||
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
|
|
||||||
else:
|
|
||||||
dependencies = []
|
|
||||||
|
|
||||||
with tf.control_dependencies(dependencies):
|
|
||||||
return sample_state, global_state
|
|
||||||
|
|
||||||
|
|
||||||
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
@ -56,10 +47,6 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||||
privatized.
|
privatized.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initializes the NoPrivacyAverageQuery."""
|
|
||||||
self._ledger = None
|
|
||||||
|
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||||
return (super(NoPrivacyAverageQuery,
|
return (super(NoPrivacyAverageQuery,
|
||||||
|
@ -103,11 +90,5 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
sum_state, denominator = sample_state
|
sum_state, denominator = sample_state
|
||||||
|
|
||||||
if self._ledger:
|
result = tf.nest.map_structure(lambda t: t / denominator, sum_state)
|
||||||
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
|
return result, global_state, dp_event.NonPrivateDpEvent()
|
||||||
else:
|
|
||||||
dependencies = []
|
|
||||||
|
|
||||||
with tf.control_dependencies(dependencies):
|
|
||||||
return (tf.nest.map_structure(lambda t: t / denominator,
|
|
||||||
sum_state), global_state)
|
|
||||||
|
|
|
@ -74,14 +74,16 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
noised_sum, new_sum_global_state, event = self._numerator.get_noised_result(
|
||||||
sample_state, global_state.numerator_state)
|
sample_state, global_state.numerator_state)
|
||||||
|
|
||||||
def normalize(v):
|
def normalize(v):
|
||||||
return tf.truediv(v, global_state.denominator)
|
return tf.truediv(v, global_state.denominator)
|
||||||
|
|
||||||
|
# The denominator is constant so the privacy cost comes from the numerator.
|
||||||
return (tf.nest.map_structure(normalize, noised_sum),
|
return (tf.nest.map_structure(normalize, noised_sum),
|
||||||
self._GlobalState(new_sum_global_state, global_state.denominator))
|
self._GlobalState(new_sum_global_state,
|
||||||
|
global_state.denominator), event)
|
||||||
|
|
||||||
def derive_metrics(self, global_state):
|
def derive_metrics(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||||
|
|
|
@ -21,6 +21,7 @@ import collections
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
|
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
|
||||||
|
@ -123,11 +124,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
noised_vectors, sum_state, sum_event = self._sum_query.get_noised_result(
|
||||||
sample_state.sum_state, global_state.sum_state)
|
sample_state.sum_state, global_state.sum_state)
|
||||||
del sum_state # To be set explicitly later when we know the new clip.
|
del sum_state # To be set explicitly later when we know the new clip.
|
||||||
|
|
||||||
new_l2_norm_clip, new_quantile_estimator_state = (
|
new_l2_norm_clip, new_quantile_estimator_state, quantile_event = (
|
||||||
self._quantile_estimator_query.get_noised_result(
|
self._quantile_estimator_query.get_noised_result(
|
||||||
sample_state.quantile_estimator_state,
|
sample_state.quantile_estimator_state,
|
||||||
global_state.quantile_estimator_state))
|
global_state.quantile_estimator_state))
|
||||||
|
@ -141,7 +142,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
new_sum_query_state,
|
new_sum_query_state,
|
||||||
new_quantile_estimator_state)
|
new_quantile_estimator_state)
|
||||||
|
|
||||||
return noised_vectors, new_global_state
|
event = dp_event.ComposedDpEvent(events=[sum_event, quantile_event])
|
||||||
|
return noised_vectors, new_global_state, event
|
||||||
|
|
||||||
def derive_metrics(self, global_state):
|
def derive_metrics(self, global_state):
|
||||||
"""Returns the current clipping norm as a metric."""
|
"""Returns the current clipping norm as a metric."""
|
||||||
|
|
|
@ -135,7 +135,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
below_estimate_result, new_below_estimate_state = (
|
below_estimate_result, new_below_estimate_state, below_estimate_event = (
|
||||||
self._below_estimate_query.get_noised_result(
|
self._below_estimate_query.get_noised_result(
|
||||||
sample_state, global_state.below_estimate_state))
|
sample_state, global_state.below_estimate_state))
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
current_estimate=new_estimate,
|
current_estimate=new_estimate,
|
||||||
below_estimate_state=new_below_estimate_state)
|
below_estimate_state=new_below_estimate_state)
|
||||||
|
|
||||||
return new_estimate, new_global_state
|
return new_estimate, new_global_state, below_estimate_event
|
||||||
|
|
||||||
def derive_metrics(self, global_state):
|
def derive_metrics(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||||
|
|
|
@ -134,14 +134,14 @@ class RestartQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
|
||||||
noised_results, inner_query_state = self._inner_query.get_noised_result(
|
noised_results, inner_state, event = self._inner_query.get_noised_result(
|
||||||
sample_state, global_state.inner_query_state)
|
sample_state, global_state.inner_query_state)
|
||||||
restart_flag, indicator_state = self._restart_indicator.next(
|
restart_flag, indicator_state = self._restart_indicator.next(
|
||||||
global_state.indicator_state)
|
global_state.indicator_state)
|
||||||
if restart_flag:
|
if restart_flag:
|
||||||
inner_query_state = self._inner_query.reset_state(noised_results,
|
inner_state = self._inner_query.reset_state(noised_results, inner_state)
|
||||||
inner_query_state)
|
return (noised_results, self._GlobalState(inner_state,
|
||||||
return noised_results, self._GlobalState(inner_query_state, indicator_state)
|
indicator_state), event)
|
||||||
|
|
||||||
def derive_metrics(self, global_state):
|
def derive_metrics(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
|
||||||
|
|
|
@ -75,7 +75,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
sample_state = query.initial_sample_state(scalar_value)
|
sample_state = query.initial_sample_state(scalar_value)
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state, _ = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
# Expected value is the combination of cumsum of signal; sum of trees
|
# Expected value is the combination of cumsum of signal; sum of trees
|
||||||
# that have been reset; current tree sum. The tree aggregation value can
|
# that have been reset; current tree sum. The tree aggregation value can
|
||||||
|
@ -110,7 +110,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
sample_state = query.initial_sample_state(scalar_value)
|
sample_state = query.initial_sample_state(scalar_value)
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state, _ = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
# Expected value is the signal of the current round plus the residual of
|
# Expected value is the signal of the current round plus the residual of
|
||||||
# two continous tree aggregation values. The tree aggregation value can
|
# two continous tree aggregation values. The tree aggregation value can
|
||||||
|
|
|
@ -44,6 +44,7 @@ def run_query(query, records, global_state=None, weights=None):
|
||||||
sample_state = query.accumulate_record(params, sample_state, record)
|
sample_state = query.accumulate_record(params, sample_state, record)
|
||||||
else:
|
else:
|
||||||
for weight, record in zip(weights, records):
|
for weight, record in zip(weights, records):
|
||||||
sample_state = query.accumulate_record(
|
sample_state = query.accumulate_record(params, sample_state, record,
|
||||||
params, sample_state, record, weight)
|
weight)
|
||||||
return query.get_noised_result(sample_state, global_state)
|
result, global_state, _ = query.get_noised_result(sample_state, global_state)
|
||||||
|
return result, global_state
|
||||||
|
|
|
@ -24,12 +24,12 @@ import math
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
|
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
|
||||||
# in the same module.
|
# in the same module.
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
for j,sample in enumerate(samples):
|
for j,sample in enumerate(samples):
|
||||||
sample_state = query.accumulate_record(params, sample_state, sample)
|
sample_state = query.accumulate_record(params, sample_state, sample)
|
||||||
# noised_cumsum is privatized estimate of s_i
|
# noised_cumsum is privatized estimate of s_i
|
||||||
noised_cumsum, global_state = query.get_noised_result(
|
noised_cumsum, global_state, event = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -176,7 +176,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
global_state,
|
global_state,
|
||||||
samples_cumulative_sum=new_cumulative_sum,
|
samples_cumulative_sum=new_cumulative_sum,
|
||||||
tree_state=new_tree_state)
|
tree_state=new_tree_state)
|
||||||
return noised_cumulative_sum, new_global_state
|
event = dp_event.UnsupportedDpEvent()
|
||||||
|
return noised_cumulative_sum, new_global_state, event
|
||||||
|
|
||||||
def reset_state(self, noised_results, global_state):
|
def reset_state(self, noised_results, global_state):
|
||||||
"""Returns state after resetting the tree.
|
"""Returns state after resetting the tree.
|
||||||
|
@ -281,7 +282,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
sample_state = query.accumulate_record(params, sample_state, sample)
|
sample_state = query.accumulate_record(params, sample_state, sample)
|
||||||
# noised_sum is privatized estimate of x_i by conceptually postprocessing
|
# noised_sum is privatized estimate of x_i by conceptually postprocessing
|
||||||
# noised cumulative sum s_i
|
# noised cumulative sum s_i
|
||||||
noised_sum, global_state = query.get_noised_result(
|
noised_sum, global_state, event = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -398,7 +399,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
global_state.previous_tree_noise)
|
global_state.previous_tree_noise)
|
||||||
new_global_state = attr.evolve(
|
new_global_state = attr.evolve(
|
||||||
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
||||||
return noised_sample, new_global_state
|
event = dp_event.UnsupportedDpEvent()
|
||||||
|
return noised_sample, new_global_state, event
|
||||||
|
|
||||||
def reset_state(self, noised_results, global_state):
|
def reset_state(self, noised_results, global_state):
|
||||||
"""Returns state after resetting the tree.
|
"""Returns state after resetting the tree.
|
||||||
|
@ -636,7 +638,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
# This part is not written in tensorflow and will be executed on the server
|
# This part is not written in tensorflow and will be executed on the server
|
||||||
# side instead of the client side if used with
|
# side instead of the client side if used with
|
||||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
|
||||||
sample_state, global_state.inner_query_state)
|
sample_state, global_state.inner_query_state)
|
||||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||||
arity=global_state.arity, inner_query_state=inner_query_state)
|
arity=global_state.arity, inner_query_state=inner_query_state)
|
||||||
|
@ -647,7 +649,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
]
|
]
|
||||||
tree = tf.RaggedTensor.from_row_splits(
|
tree = tf.RaggedTensor.from_row_splits(
|
||||||
values=sample_state, row_splits=row_splits)
|
values=sample_state, row_splits=row_splits)
|
||||||
return tree, new_global_state
|
event = dp_event.UnsupportedDpEvent()
|
||||||
|
return tree, new_global_state, event
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_central_gaussian_query(cls,
|
def build_central_gaussian_query(cls,
|
||||||
|
|
|
@ -258,7 +258,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for scalar, expected_sum in zip(streaming_scalars, partial_sum):
|
for scalar, expected_sum in zip(streaming_scalars, partial_sum):
|
||||||
sample_state = query.initial_sample_state(scalar)
|
sample_state = query.initial_sample_state(scalar)
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar)
|
sample_state = query.accumulate_record(params, sample_state, scalar)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state, _ = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
self.assertEqual(query_result, expected_sum)
|
self.assertEqual(query_result, expected_sum)
|
||||||
|
|
||||||
|
@ -282,7 +282,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
sample_state = query.initial_sample_state(scalar_value)
|
sample_state = query.initial_sample_state(scalar_value)
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state, _ = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
# For each streaming step i , the expected value is roughly
|
# For each streaming step i , the expected value is roughly
|
||||||
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
|
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
|
||||||
|
@ -314,7 +314,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
sample_state = query.initial_sample_state(scalar_value)
|
sample_state = query.initial_sample_state(scalar_value)
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state, _ = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
if i % frequency == frequency - 1:
|
if i % frequency == frequency - 1:
|
||||||
global_state = query.reset_state(query_result, global_state)
|
global_state = query.reset_state(query_result, global_state)
|
||||||
|
@ -456,7 +456,7 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
sample_state = query.initial_sample_state(scalar_value)
|
sample_state = query.initial_sample_state(scalar_value)
|
||||||
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
sample_state = query.accumulate_record(params, sample_state, scalar_value)
|
||||||
query_result, global_state = query.get_noised_result(
|
query_result, global_state, _ = query.get_noised_result(
|
||||||
sample_state, global_state)
|
sample_state, global_state)
|
||||||
if i % frequency == frequency - 1:
|
if i % frequency == frequency - 1:
|
||||||
global_state = query.reset_state(query_result, global_state)
|
global_state = query.reset_state(query_result, global_state)
|
||||||
|
@ -609,7 +609,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, record)
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
sample_state, global_state = query.get_noised_result(
|
sample_state, global_state, _ = query.get_noised_result(
|
||||||
preprocessed_record, global_state)
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
self.assertAllClose(sample_state, expected_tree)
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
@ -621,7 +621,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
||||||
sample_state, global_state = query.get_noised_result(
|
sample_state, global_state, _ = query.get_noised_result(
|
||||||
preprocessed_record, global_state)
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
|
|
|
@ -20,6 +20,7 @@ import math
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.analysis import dp_event
|
||||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
|
@ -189,7 +190,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
# This part is not written in tensorflow and will be executed on the server
|
# This part is not written in tensorflow and will be executed on the server
|
||||||
# side instead of the client side if used with
|
# side instead of the client side if used with
|
||||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
|
||||||
sample_state, global_state.inner_query_state)
|
sample_state, global_state.inner_query_state)
|
||||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||||
arity=global_state.arity, inner_query_state=inner_query_state)
|
arity=global_state.arity, inner_query_state=inner_query_state)
|
||||||
|
@ -200,7 +201,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
]
|
]
|
||||||
tree = tf.RaggedTensor.from_row_splits(
|
tree = tf.RaggedTensor.from_row_splits(
|
||||||
values=sample_state, row_splits=row_splits)
|
values=sample_state, row_splits=row_splits)
|
||||||
return tree, new_global_state
|
event = dp_event.UnsupportedDpEvent()
|
||||||
|
return tree, new_global_state, event
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_central_gaussian_query(cls,
|
def build_central_gaussian_query(cls,
|
||||||
|
|
|
@ -159,7 +159,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, record)
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
sample_state, global_state = query.get_noised_result(
|
sample_state, global_state, _ = query.get_noised_result(
|
||||||
preprocessed_record, global_state)
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
self.assertAllClose(sample_state, expected_tree)
|
self.assertAllClose(sample_state, expected_tree)
|
||||||
|
@ -171,7 +171,7 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
|
||||||
sample_state, global_state = query.get_noised_result(
|
sample_state, global_state, _ = query.get_noised_result(
|
||||||
preprocessed_record, global_state)
|
preprocessed_record, global_state)
|
||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
|
|
|
@ -164,7 +164,7 @@ def make_optimizer_class(cls):
|
||||||
for idx in range(self._num_microbatches):
|
for idx in range(self._num_microbatches):
|
||||||
sample_state = process_microbatch(idx, sample_state)
|
sample_state = process_microbatch(idx, sample_state)
|
||||||
|
|
||||||
grad_sums, self._global_state = (
|
grad_sums, self._global_state, _ = (
|
||||||
self._dp_sum_query.get_noised_result(sample_state,
|
self._dp_sum_query.get_noised_result(sample_state,
|
||||||
self._global_state))
|
self._global_state))
|
||||||
|
|
||||||
|
@ -235,7 +235,7 @@ def make_optimizer_class(cls):
|
||||||
_, sample_state = tf.while_loop(
|
_, sample_state = tf.while_loop(
|
||||||
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
|
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
|
||||||
|
|
||||||
grad_sums, self._global_state = (
|
grad_sums, self._global_state, _ = (
|
||||||
self._dp_sum_query.get_noised_result(sample_state,
|
self._dp_sum_query.get_noised_result(sample_state,
|
||||||
self._global_state))
|
self._global_state))
|
||||||
|
|
||||||
|
@ -363,10 +363,6 @@ def make_gaussian_optimizer_class(cls):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@property
|
|
||||||
def ledger(self):
|
|
||||||
return self._dp_sum_query.ledger
|
|
||||||
|
|
||||||
return DPGaussianOptimizerClass
|
return DPGaussianOptimizerClass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -81,9 +81,10 @@ def make_keras_optimizer_class(cls):
|
||||||
model.fit(...)
|
model.fit(...)
|
||||||
```
|
```
|
||||||
|
|
||||||
""".format(base_class='tf.keras.optimizers.' + cls.__name__,
|
""".format(
|
||||||
short_base_class=cls.__name__,
|
base_class='tf.keras.optimizers.' + cls.__name__,
|
||||||
dp_keras_class='DPKeras' + cls.__name__)
|
short_base_class=cls.__name__,
|
||||||
|
dp_keras_class='DPKeras' + cls.__name__)
|
||||||
|
|
||||||
# The class tf.keras.optimizers.Optimizer has two methods to compute
|
# The class tf.keras.optimizers.Optimizer has two methods to compute
|
||||||
# gradients, `_compute_gradients` and `get_gradients`. The first works
|
# gradients, `_compute_gradients` and `get_gradients`. The first works
|
||||||
|
@ -106,8 +107,8 @@ def make_keras_optimizer_class(cls):
|
||||||
Args:
|
Args:
|
||||||
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
|
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
|
||||||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||||
num_microbatches: Number of microbatches into which each minibatch
|
num_microbatches: Number of microbatches into which each minibatch is
|
||||||
is split.
|
split.
|
||||||
*args: These will be passed on to the base class `__init__` method.
|
*args: These will be passed on to the base class `__init__` method.
|
||||||
**kwargs: These will be passed on to the base class `__init__` method.
|
**kwargs: These will be passed on to the base class `__init__` method.
|
||||||
"""
|
"""
|
||||||
|
@ -210,7 +211,7 @@ def make_keras_optimizer_class(cls):
|
||||||
sample_state = self._dp_sum_query.initial_sample_state(params)
|
sample_state = self._dp_sum_query.initial_sample_state(params)
|
||||||
for idx in range(self._num_microbatches):
|
for idx in range(self._num_microbatches):
|
||||||
sample_state = process_microbatch(idx, sample_state)
|
sample_state = process_microbatch(idx, sample_state)
|
||||||
grad_sums, self._global_state = (
|
grad_sums, self._global_state, _ = (
|
||||||
self._dp_sum_query.get_noised_result(sample_state,
|
self._dp_sum_query.get_noised_result(sample_state,
|
||||||
self._global_state))
|
self._global_state))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue