Simplify GaussianQuery by removing _GlobalState.

The global state of a DPQuery is intended for aspects of the query that change across samples under the query's own control. For the basic Gaussian query classes the clipping norm and noise stddev are fixed at construction, so it was unnecessary to wrap "l2_norm_clip" and "sum_stddev" in the _GlobalState namedtuple.

PiperOrigin-RevId: 237528962
Galen Andrew 2019-03-08 15:17:30 -08:00 committed by A. Unique TensorFlower
parent f85c04c072
commit e566967ff6
2 changed files with 39 additions and 27 deletions
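Before the per-file changes, it may help to see how a caller drives a DPQuery. The sketch below is illustrative only: the TF 1.x session code, the example records, and the privacy.optimizers.gaussian_query import path are assumptions inferred from the imports in this diff, not part of the commit. It shows why l2_norm_clip and stddev need not live in the global state: they are fixed at construction and never change across samples.

import tensorflow as tf
from privacy.optimizers import gaussian_query  # assumed path, per the imports below

# Clip norm and noise stddev are fixed when the query is built.
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)

global_state = query.initial_global_state()        # simply None after this change
params = query.derive_sample_params(global_state)  # the clip norm

records = [tf.constant([-6.0, 8.0]),   # norm 10, clipped to [-3.0, 4.0]
           tf.constant([4.0, -3.0])]   # norm 5, not clipped

sample_state = query.initial_sample_state(global_state, records[0])
for record in records:
  sample_state = query.accumulate_record(params, sample_state, record)

noised_sum, global_state = query.get_noised_result(sample_state, global_state)

with tf.Session() as sess:
  print(sess.run(noised_sum))  # [1.0, 1.0]: the clipped sum, no noise at stddev=0.0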

privacy/optimizers/gaussian_query.py

@@ -19,8 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
import tensorflow as tf
from privacy.optimizers import dp_query
@@ -34,10 +32,6 @@ class GaussianSumQuery(dp_query.DPQuery):
Accumulates clipped vectors, then adds Gaussian noise to the sum.
"""
-  # pylint: disable=invalid-name
-  _GlobalState = collections.namedtuple(
-      '_GlobalState', ['l2_norm_clip', 'stddev'])
def __init__(self, l2_norm_clip, stddev, ledger=None):
"""Initializes the GaussianSumQuery.
@@ -47,13 +41,13 @@ class GaussianSumQuery(dp_query.DPQuery):
stddev: The stddev of the noise added to the sum.
ledger: The privacy ledger to which queries should be recorded.
"""
-    self._l2_norm_clip = l2_norm_clip
-    self._stddev = stddev
+    self._l2_norm_clip = tf.to_float(l2_norm_clip)
+    self._stddev = tf.to_float(stddev)
self._ledger = ledger
def initial_global_state(self):
"""Returns the initial global state for the GaussianSumQuery."""
-    return self._GlobalState(float(self._l2_norm_clip), float(self._stddev))
+    return None
def derive_sample_params(self, global_state):
"""Given the global state, derives parameters to use for the next sample.
@@ -64,7 +58,7 @@ class GaussianSumQuery(dp_query.DPQuery):
Returns:
Parameters to use to process records in the next sample.
"""
-    return global_state.l2_norm_clip
+    return self._l2_norm_clip
def initial_sample_state(self, global_state, tensors):
"""Returns an initial state to use for the next sample.
@@ -77,7 +71,9 @@ class GaussianSumQuery(dp_query.DPQuery):
Returns: An initial sample state.
"""
if self._ledger:
-      dependencies = [self._ledger.record_sum_query(*global_state)]
+      dependencies = [
+          self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
+      ]
else:
dependencies = []
with tf.control_dependencies(dependencies):
@@ -112,7 +108,7 @@ class GaussianSumQuery(dp_query.DPQuery):
sum of the records and "new_global_state" is the updated global state.
"""
def add_noise(v):
-      return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev)
+      return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
return nest.map_structure(add_noise, sample_state), global_state
@@ -128,10 +124,6 @@ class GaussianAverageQuery(dp_query.DPQuery):
variance estimator.
"""
-  # pylint: disable=invalid-name
-  _GlobalState = collections.namedtuple(
-      '_GlobalState', ['sum_state', 'denominator'])
def __init__(self,
l2_norm_clip,
sum_stddev,
@@ -149,12 +141,12 @@ class GaussianAverageQuery(dp_query.DPQuery):
ledger: The privacy ledger to which queries should be recorded.
"""
self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev, ledger)
-    self._denominator = denominator
+    self._denominator = tf.to_float(denominator)
def initial_global_state(self):
"""Returns the initial global state for the GaussianAverageQuery."""
-    sum_global_state = self._numerator.initial_global_state()
-    return self._GlobalState(sum_global_state, float(self._denominator))
+    # GaussianAverageQuery has no global state beyond the numerator state.
+    return self._numerator.initial_global_state()
def derive_sample_params(self, global_state):
"""Given the global state, derives parameters to use for the next sample.
@@ -165,7 +157,7 @@ class GaussianAverageQuery(dp_query.DPQuery):
Returns:
Parameters to use to process records in the next sample.
"""
-    return self._numerator.derive_sample_params(global_state.sum_state)
+    return self._numerator.derive_sample_params(global_state)
def initial_sample_state(self, global_state, tensors):
"""Returns an initial state to use for the next sample.
@@ -177,8 +169,8 @@ class GaussianAverageQuery(dp_query.DPQuery):
Returns: An initial sample state.
"""
-    # GaussianAverageQuery has no state beyond the sum state.
-    return self._numerator.initial_sample_state(global_state.sum_state, tensors)
+    # GaussianAverageQuery has no sample state beyond the sum state.
+    return self._numerator.initial_sample_state(global_state, tensors)
def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
@@ -205,10 +197,8 @@ class GaussianAverageQuery(dp_query.DPQuery):
average of the records and "new_global_state" is the updated global state.
"""
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
-        sample_state, global_state.sum_state)
-    new_global_state = self._GlobalState(
-        new_sum_global_state, global_state.denominator)
+        sample_state, global_state)
def normalize(v):
-      return tf.truediv(v, global_state.denominator)
+      return tf.truediv(v, self._denominator)
-    return nest.map_structure(normalize, noised_sum), new_global_state
+    return nest.map_structure(normalize, noised_sum), new_sum_global_state
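The GaussianAverageQuery changes above thread the numerator's global state straight through and divide by the fixed denominator inside normalize(). A hedged, self-contained sketch of the resulting behaviour follows; the records, the denominator value, and the session code are illustrative assumptions, not code from this commit.

import tensorflow as tf
from privacy.optimizers import gaussian_query  # assumed path, per the imports above

avg_query = gaussian_query.GaussianAverageQuery(
    l2_norm_clip=5.0, sum_stddev=0.0, denominator=2.0)

# The average query's global state is now exactly its numerator's state.
global_state = avg_query.initial_global_state()
params = avg_query.derive_sample_params(global_state)

records = [tf.constant([-6.0, 8.0]),   # clipped to [-3.0, 4.0]
           tf.constant([4.0, -3.0])]   # not clipped

sample_state = avg_query.initial_sample_state(global_state, records[0])
for record in records:
  sample_state = avg_query.accumulate_record(params, sample_state, record)

# get_noised_result returns the numerator's updated global state unchanged,
# and the noised sum is divided by the constant denominator.
average, global_state = avg_query.get_noised_result(sample_state, global_state)

with tf.Session() as sess:
  print(sess.run(average))  # ([-3., 4.] + [4., -3.]) / 2 = [0.5, 0.5]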

privacy/optimizers/gaussian_query_test.py

@@ -53,6 +53,28 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
+  def test_gaussian_sum_with_changing_clip_no_noise(self):
+    with self.cached_session() as sess:
+      record1 = tf.constant([-6.0, 8.0])  # Clipped to [-3.0, 4.0].
+      record2 = tf.constant([4.0, -3.0])  # Not clipped.
+      l2_norm_clip = tf.Variable(5.0)
+      l2_norm_clip_placeholder = tf.placeholder(tf.float32)
+      assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder)
+      query = gaussian_query.GaussianSumQuery(
+          l2_norm_clip=l2_norm_clip, stddev=0.0)
+      query_result = test_utils.run_query(query, [record1, record2])
+      self.evaluate(tf.global_variables_initializer())
+      result = sess.run(query_result)
+      expected = [1.0, 1.0]
+      self.assertAllClose(result, expected)
+      sess.run(assign_l2_norm_clip, {l2_norm_clip_placeholder: 0.0})
+      result = sess.run(query_result)
+      expected = [0.0, 0.0]
+      self.assertAllClose(result, expected)
def test_gaussian_sum_with_noise(self):
with self.cached_session() as sess:
record1, record2 = 2.71828, 3.14159