diff --git a/privacy/optimizers/gaussian_query.py b/privacy/optimizers/gaussian_query.py index 3df8cbf..c2458d9 100644 --- a/privacy/optimizers/gaussian_query.py +++ b/privacy/optimizers/gaussian_query.py @@ -19,8 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - import tensorflow as tf from privacy.optimizers import dp_query @@ -34,10 +32,6 @@ class GaussianSumQuery(dp_query.DPQuery): Accumulates clipped vectors, then adds Gaussian noise to the sum. """ - # pylint: disable=invalid-name - _GlobalState = collections.namedtuple( - '_GlobalState', ['l2_norm_clip', 'stddev']) - def __init__(self, l2_norm_clip, stddev, ledger=None): """Initializes the GaussianSumQuery. @@ -47,13 +41,13 @@ class GaussianSumQuery(dp_query.DPQuery): stddev: The stddev of the noise added to the sum. ledger: The privacy ledger to which queries should be recorded. """ - self._l2_norm_clip = l2_norm_clip - self._stddev = stddev + self._l2_norm_clip = tf.to_float(l2_norm_clip) + self._stddev = tf.to_float(stddev) self._ledger = ledger def initial_global_state(self): """Returns the initial global state for the GaussianSumQuery.""" - return self._GlobalState(float(self._l2_norm_clip), float(self._stddev)) + return None def derive_sample_params(self, global_state): """Given the global state, derives parameters to use for the next sample. @@ -64,7 +58,7 @@ class GaussianSumQuery(dp_query.DPQuery): Returns: Parameters to use to process records in the next sample. """ - return global_state.l2_norm_clip + return self._l2_norm_clip def initial_sample_state(self, global_state, tensors): """Returns an initial state to use for the next sample. @@ -77,7 +71,9 @@ class GaussianSumQuery(dp_query.DPQuery): Returns: An initial sample state. """ if self._ledger: - dependencies = [self._ledger.record_sum_query(*global_state)] + dependencies = [ + self._ledger.record_sum_query(self._l2_norm_clip, self._stddev) + ] else: dependencies = [] with tf.control_dependencies(dependencies): @@ -112,7 +108,7 @@ class GaussianSumQuery(dp_query.DPQuery): sum of the records and "new_global_state" is the updated global state. """ def add_noise(v): - return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev) + return v + tf.random_normal(tf.shape(v), stddev=self._stddev) return nest.map_structure(add_noise, sample_state), global_state @@ -128,10 +124,6 @@ class GaussianAverageQuery(dp_query.DPQuery): variance estimator. """ - # pylint: disable=invalid-name - _GlobalState = collections.namedtuple( - '_GlobalState', ['sum_state', 'denominator']) - def __init__(self, l2_norm_clip, sum_stddev, @@ -149,12 +141,12 @@ class GaussianAverageQuery(dp_query.DPQuery): ledger: The privacy ledger to which queries should be recorded. """ self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev, ledger) - self._denominator = denominator + self._denominator = tf.to_float(denominator) def initial_global_state(self): """Returns the initial global state for the GaussianAverageQuery.""" - sum_global_state = self._numerator.initial_global_state() - return self._GlobalState(sum_global_state, float(self._denominator)) + # GaussianAverageQuery has no global state beyond the numerator state. + return self._numerator.initial_global_state() def derive_sample_params(self, global_state): """Given the global state, derives parameters to use for the next sample. @@ -165,7 +157,7 @@ class GaussianAverageQuery(dp_query.DPQuery): Returns: Parameters to use to process records in the next sample. """ - return self._numerator.derive_sample_params(global_state.sum_state) + return self._numerator.derive_sample_params(global_state) def initial_sample_state(self, global_state, tensors): """Returns an initial state to use for the next sample. @@ -177,8 +169,8 @@ class GaussianAverageQuery(dp_query.DPQuery): Returns: An initial sample state. """ - # GaussianAverageQuery has no state beyond the sum state. - return self._numerator.initial_sample_state(global_state.sum_state, tensors) + # GaussianAverageQuery has no sample state beyond the sum state. + return self._numerator.initial_sample_state(global_state, tensors) def accumulate_record(self, params, sample_state, record): """Accumulates a single record into the sample state. @@ -205,10 +197,8 @@ class GaussianAverageQuery(dp_query.DPQuery): average of the records and "new_global_state" is the updated global state. """ noised_sum, new_sum_global_state = self._numerator.get_noised_result( - sample_state, global_state.sum_state) - new_global_state = self._GlobalState( - new_sum_global_state, global_state.denominator) + sample_state, global_state) def normalize(v): - return tf.truediv(v, global_state.denominator) + return tf.truediv(v, self._denominator) - return nest.map_structure(normalize, noised_sum), new_global_state + return nest.map_structure(normalize, noised_sum), new_sum_global_state diff --git a/privacy/optimizers/gaussian_query_test.py b/privacy/optimizers/gaussian_query_test.py index b66a442..299372e 100644 --- a/privacy/optimizers/gaussian_query_test.py +++ b/privacy/optimizers/gaussian_query_test.py @@ -53,6 +53,28 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase): expected = [1.0, 1.0] self.assertAllClose(result, expected) + def test_gaussian_sum_with_changing_clip_no_noise(self): + with self.cached_session() as sess: + record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. + record2 = tf.constant([4.0, -3.0]) # Not clipped. + + l2_norm_clip = tf.Variable(5.0) + l2_norm_clip_placeholder = tf.placeholder(tf.float32) + assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder) + query = gaussian_query.GaussianSumQuery( + l2_norm_clip=l2_norm_clip, stddev=0.0) + query_result = test_utils.run_query(query, [record1, record2]) + + self.evaluate(tf.global_variables_initializer()) + result = sess.run(query_result) + expected = [1.0, 1.0] + self.assertAllClose(result, expected) + + sess.run(assign_l2_norm_clip, {l2_norm_clip_placeholder: 0.0}) + result = sess.run(query_result) + expected = [0.0, 0.0] + self.assertAllClose(result, expected) + def test_gaussian_sum_with_noise(self): with self.cached_session() as sess: record1, record2 = 2.71828, 3.14159