forked from 626_privacy/tensorflow_privacy
Simplify GaussianQuery by removing _GlobalState.
The global state for DP query is intended for aspects of the query that change across samples under the query's own control. It was therefore unnecessary to wrap "l2_norm_clip" and "sum_stddev" in the namedtuple _GlobalState for the basic GaussianQuery classes. PiperOrigin-RevId: 237528962
This commit is contained in:
parent
f85c04c072
commit
e566967ff6
2 changed files with 39 additions and 27 deletions
|
@ -19,8 +19,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from privacy.optimizers import dp_query
|
||||
|
@ -34,10 +32,6 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
Accumulates clipped vectors, then adds Gaussian noise to the sum.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple(
|
||||
'_GlobalState', ['l2_norm_clip', 'stddev'])
|
||||
|
||||
def __init__(self, l2_norm_clip, stddev, ledger=None):
|
||||
"""Initializes the GaussianSumQuery.
|
||||
|
||||
|
@ -47,13 +41,13 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
stddev: The stddev of the noise added to the sum.
|
||||
ledger: The privacy ledger to which queries should be recorded.
|
||||
"""
|
||||
self._l2_norm_clip = l2_norm_clip
|
||||
self._stddev = stddev
|
||||
self._l2_norm_clip = tf.to_float(l2_norm_clip)
|
||||
self._stddev = tf.to_float(stddev)
|
||||
self._ledger = ledger
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the GaussianSumQuery."""
|
||||
return self._GlobalState(float(self._l2_norm_clip), float(self._stddev))
|
||||
return None
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
@ -64,7 +58,7 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
Returns:
|
||||
Parameters to use to process records in the next sample.
|
||||
"""
|
||||
return global_state.l2_norm_clip
|
||||
return self._l2_norm_clip
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
@ -77,7 +71,9 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
Returns: An initial sample state.
|
||||
"""
|
||||
if self._ledger:
|
||||
dependencies = [self._ledger.record_sum_query(*global_state)]
|
||||
dependencies = [
|
||||
self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
|
||||
]
|
||||
else:
|
||||
dependencies = []
|
||||
with tf.control_dependencies(dependencies):
|
||||
|
@ -112,7 +108,7 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
sum of the records and "new_global_state" is the updated global state.
|
||||
"""
|
||||
def add_noise(v):
|
||||
return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev)
|
||||
return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
|
||||
|
||||
return nest.map_structure(add_noise, sample_state), global_state
|
||||
|
||||
|
@ -128,10 +124,6 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
|||
variance estimator.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple(
|
||||
'_GlobalState', ['sum_state', 'denominator'])
|
||||
|
||||
def __init__(self,
|
||||
l2_norm_clip,
|
||||
sum_stddev,
|
||||
|
@ -149,12 +141,12 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
|||
ledger: The privacy ledger to which queries should be recorded.
|
||||
"""
|
||||
self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev, ledger)
|
||||
self._denominator = denominator
|
||||
self._denominator = tf.to_float(denominator)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the GaussianAverageQuery."""
|
||||
sum_global_state = self._numerator.initial_global_state()
|
||||
return self._GlobalState(sum_global_state, float(self._denominator))
|
||||
# GaussianAverageQuery has no global state beyond the numerator state.
|
||||
return self._numerator.initial_global_state()
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
@ -165,7 +157,7 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
|||
Returns:
|
||||
Parameters to use to process records in the next sample.
|
||||
"""
|
||||
return self._numerator.derive_sample_params(global_state.sum_state)
|
||||
return self._numerator.derive_sample_params(global_state)
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
@ -177,8 +169,8 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
|||
|
||||
Returns: An initial sample state.
|
||||
"""
|
||||
# GaussianAverageQuery has no state beyond the sum state.
|
||||
return self._numerator.initial_sample_state(global_state.sum_state, tensors)
|
||||
# GaussianAverageQuery has no sample state beyond the sum state.
|
||||
return self._numerator.initial_sample_state(global_state, tensors)
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
|
@ -205,10 +197,8 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
|||
average of the records and "new_global_state" is the updated global state.
|
||||
"""
|
||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||
sample_state, global_state.sum_state)
|
||||
new_global_state = self._GlobalState(
|
||||
new_sum_global_state, global_state.denominator)
|
||||
sample_state, global_state)
|
||||
def normalize(v):
|
||||
return tf.truediv(v, global_state.denominator)
|
||||
return tf.truediv(v, self._denominator)
|
||||
|
||||
return nest.map_structure(normalize, noised_sum), new_global_state
|
||||
return nest.map_structure(normalize, noised_sum), new_sum_global_state
|
||||
|
|
|
@ -53,6 +53,28 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_gaussian_sum_with_changing_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||
|
||||
l2_norm_clip = tf.Variable(5.0)
|
||||
l2_norm_clip_placeholder = tf.placeholder(tf.float32)
|
||||
assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder)
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=l2_norm_clip, stddev=0.0)
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
sess.run(assign_l2_norm_clip, {l2_norm_clip_placeholder: 0.0})
|
||||
result = sess.run(query_result)
|
||||
expected = [0.0, 0.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_gaussian_sum_with_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1, record2 = 2.71828, 3.14159
|
||||
|
|
Loading…
Reference in a new issue