forked from 626_privacy/tensorflow_privacy
Simplify GaussianQuery by removing _GlobalState.
The global state of a DP query is intended for aspects of the query that change across samples under the query's own control. It was therefore unnecessary to wrap "l2_norm_clip" and "sum_stddev" in the namedtuple _GlobalState for the basic GaussianQuery classes.
PiperOrigin-RevId: 237528962
This commit is contained in:
parent
f85c04c072
commit
e566967ff6
2 changed files with 39 additions and 27 deletions
|
@ -19,8 +19,6 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from privacy.optimizers import dp_query
|
from privacy.optimizers import dp_query
|
||||||
|
@ -34,10 +32,6 @@ class GaussianSumQuery(dp_query.DPQuery):
|
||||||
Accumulates clipped vectors, then adds Gaussian noise to the sum.
|
Accumulates clipped vectors, then adds Gaussian noise to the sum.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
_GlobalState = collections.namedtuple(
|
|
||||||
'_GlobalState', ['l2_norm_clip', 'stddev'])
|
|
||||||
|
|
||||||
def __init__(self, l2_norm_clip, stddev, ledger=None):
|
def __init__(self, l2_norm_clip, stddev, ledger=None):
|
||||||
"""Initializes the GaussianSumQuery.
|
"""Initializes the GaussianSumQuery.
|
||||||
|
|
||||||
|
@ -47,13 +41,13 @@ class GaussianSumQuery(dp_query.DPQuery):
|
||||||
stddev: The stddev of the noise added to the sum.
|
stddev: The stddev of the noise added to the sum.
|
||||||
ledger: The privacy ledger to which queries should be recorded.
|
ledger: The privacy ledger to which queries should be recorded.
|
||||||
"""
|
"""
|
||||||
self._l2_norm_clip = l2_norm_clip
|
self._l2_norm_clip = tf.to_float(l2_norm_clip)
|
||||||
self._stddev = stddev
|
self._stddev = tf.to_float(stddev)
|
||||||
self._ledger = ledger
|
self._ledger = ledger
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Returns the initial global state for the GaussianSumQuery."""
|
"""Returns the initial global state for the GaussianSumQuery."""
|
||||||
return self._GlobalState(float(self._l2_norm_clip), float(self._stddev))
|
return None
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Given the global state, derives parameters to use for the next sample.
|
"""Given the global state, derives parameters to use for the next sample.
|
||||||
|
@ -64,7 +58,7 @@ class GaussianSumQuery(dp_query.DPQuery):
|
||||||
Returns:
|
Returns:
|
||||||
Parameters to use to process records in the next sample.
|
Parameters to use to process records in the next sample.
|
||||||
"""
|
"""
|
||||||
return global_state.l2_norm_clip
|
return self._l2_norm_clip
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def initial_sample_state(self, global_state, tensors):
|
||||||
"""Returns an initial state to use for the next sample.
|
"""Returns an initial state to use for the next sample.
|
||||||
|
@ -77,7 +71,9 @@ class GaussianSumQuery(dp_query.DPQuery):
|
||||||
Returns: An initial sample state.
|
Returns: An initial sample state.
|
||||||
"""
|
"""
|
||||||
if self._ledger:
|
if self._ledger:
|
||||||
dependencies = [self._ledger.record_sum_query(*global_state)]
|
dependencies = [
|
||||||
|
self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
dependencies = []
|
dependencies = []
|
||||||
with tf.control_dependencies(dependencies):
|
with tf.control_dependencies(dependencies):
|
||||||
|
@ -112,7 +108,7 @@ class GaussianSumQuery(dp_query.DPQuery):
|
||||||
sum of the records and "new_global_state" is the updated global state.
|
sum of the records and "new_global_state" is the updated global state.
|
||||||
"""
|
"""
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev)
|
return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
|
||||||
|
|
||||||
return nest.map_structure(add_noise, sample_state), global_state
|
return nest.map_structure(add_noise, sample_state), global_state
|
||||||
|
|
||||||
|
@ -128,10 +124,6 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
||||||
variance estimator.
|
variance estimator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
_GlobalState = collections.namedtuple(
|
|
||||||
'_GlobalState', ['sum_state', 'denominator'])
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
l2_norm_clip,
|
l2_norm_clip,
|
||||||
sum_stddev,
|
sum_stddev,
|
||||||
|
@ -149,12 +141,12 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
||||||
ledger: The privacy ledger to which queries should be recorded.
|
ledger: The privacy ledger to which queries should be recorded.
|
||||||
"""
|
"""
|
||||||
self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev, ledger)
|
self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev, ledger)
|
||||||
self._denominator = denominator
|
self._denominator = tf.to_float(denominator)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Returns the initial global state for the GaussianAverageQuery."""
|
"""Returns the initial global state for the GaussianAverageQuery."""
|
||||||
sum_global_state = self._numerator.initial_global_state()
|
# GaussianAverageQuery has no global state beyond the numerator state.
|
||||||
return self._GlobalState(sum_global_state, float(self._denominator))
|
return self._numerator.initial_global_state()
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Given the global state, derives parameters to use for the next sample.
|
"""Given the global state, derives parameters to use for the next sample.
|
||||||
|
@ -165,7 +157,7 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
||||||
Returns:
|
Returns:
|
||||||
Parameters to use to process records in the next sample.
|
Parameters to use to process records in the next sample.
|
||||||
"""
|
"""
|
||||||
return self._numerator.derive_sample_params(global_state.sum_state)
|
return self._numerator.derive_sample_params(global_state)
|
||||||
|
|
||||||
def initial_sample_state(self, global_state, tensors):
|
def initial_sample_state(self, global_state, tensors):
|
||||||
"""Returns an initial state to use for the next sample.
|
"""Returns an initial state to use for the next sample.
|
||||||
|
@ -177,8 +169,8 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
||||||
|
|
||||||
Returns: An initial sample state.
|
Returns: An initial sample state.
|
||||||
"""
|
"""
|
||||||
# GaussianAverageQuery has no state beyond the sum state.
|
# GaussianAverageQuery has no sample state beyond the sum state.
|
||||||
return self._numerator.initial_sample_state(global_state.sum_state, tensors)
|
return self._numerator.initial_sample_state(global_state, tensors)
|
||||||
|
|
||||||
def accumulate_record(self, params, sample_state, record):
|
def accumulate_record(self, params, sample_state, record):
|
||||||
"""Accumulates a single record into the sample state.
|
"""Accumulates a single record into the sample state.
|
||||||
|
@ -205,10 +197,8 @@ class GaussianAverageQuery(dp_query.DPQuery):
|
||||||
average of the records and "new_global_state" is the updated global state.
|
average of the records and "new_global_state" is the updated global state.
|
||||||
"""
|
"""
|
||||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||||
sample_state, global_state.sum_state)
|
sample_state, global_state)
|
||||||
new_global_state = self._GlobalState(
|
|
||||||
new_sum_global_state, global_state.denominator)
|
|
||||||
def normalize(v):
|
def normalize(v):
|
||||||
return tf.truediv(v, global_state.denominator)
|
return tf.truediv(v, self._denominator)
|
||||||
|
|
||||||
return nest.map_structure(normalize, noised_sum), new_global_state
|
return nest.map_structure(normalize, noised_sum), new_sum_global_state
|
||||||
|
|
|
@ -53,6 +53,28 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
|
def test_gaussian_sum_with_changing_clip_no_noise(self):
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||||
|
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||||
|
|
||||||
|
l2_norm_clip = tf.Variable(5.0)
|
||||||
|
l2_norm_clip_placeholder = tf.placeholder(tf.float32)
|
||||||
|
assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder)
|
||||||
|
query = gaussian_query.GaussianSumQuery(
|
||||||
|
l2_norm_clip=l2_norm_clip, stddev=0.0)
|
||||||
|
query_result = test_utils.run_query(query, [record1, record2])
|
||||||
|
|
||||||
|
self.evaluate(tf.global_variables_initializer())
|
||||||
|
result = sess.run(query_result)
|
||||||
|
expected = [1.0, 1.0]
|
||||||
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
|
sess.run(assign_l2_norm_clip, {l2_norm_clip_placeholder: 0.0})
|
||||||
|
result = sess.run(query_result)
|
||||||
|
expected = [0.0, 0.0]
|
||||||
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
def test_gaussian_sum_with_noise(self):
|
def test_gaussian_sum_with_noise(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
record1, record2 = 2.71828, 3.14159
|
record1, record2 = 2.71828, 3.14159
|
||||||
|
|
Loading…
Reference in a new issue