From 4d0ab48c3563c9254fc93ef651d8e86b507e1256 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 8 Feb 2019 11:21:20 -0800
Subject: [PATCH] Add privacy ledger.

The privacy ledger keeps a record of all sampling and query events for
post hoc analysis by the privacy accountant.

PiperOrigin-RevId: 233094012
---
 privacy/analysis/privacy_ledger.py      | 199 ++++++++++++++++++++++++
 privacy/analysis/privacy_ledger_test.py | 134 ++++++++++++++++
 privacy/analysis/tensor_buffer.py       |  91 +++++++++++
 privacy/analysis/tensor_buffer_test.py  |  73 +++++++++
 privacy/optimizers/dp_optimizer.py      |  14 ++
 privacy/optimizers/dp_optimizer_test.py |  28 +++-
 privacy/optimizers/gaussian_query.py    |  21 ++-
 tutorials/mnist_dpsgd_tutorial.py       |   3 +-
 8 files changed, 553 insertions(+), 10 deletions(-)
 create mode 100644 privacy/analysis/privacy_ledger.py
 create mode 100644 privacy/analysis/privacy_ledger_test.py
 create mode 100644 privacy/analysis/tensor_buffer.py
 create mode 100644 privacy/analysis/tensor_buffer_test.py

diff --git a/privacy/analysis/privacy_ledger.py b/privacy/analysis/privacy_ledger.py
new file mode 100644
index 0000000..1062e08
--- /dev/null
+++ b/privacy/analysis/privacy_ledger.py
@@ -0,0 +1,199 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PrivacyLedger class for keeping a record of private queries.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import tensorflow as tf
+
+from tensorflow_privacy.privacy.analysis import tensor_buffer
+from tensorflow_privacy.privacy.optimizers import dp_query
+
+nest = tf.contrib.framework.nest
+
+SampleEntry = collections.namedtuple(  # pylint: disable=invalid-name
+    'SampleEntry', ['population_size', 'selection_probability', 'queries'])
+
+GaussianSumQueryEntry = collections.namedtuple(  # pylint: disable=invalid-name
+    'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev'])
+
+
+class PrivacyLedger(object):
+  """Class for keeping a record of private queries.
+
+  The PrivacyLedger keeps a record of all queries executed over a given dataset
+  for the purpose of computing privacy guarantees.
+  """
+
+  def __init__(
+      self,
+      population_size,
+      selection_probability,
+      max_samples,
+      max_queries):
+    """Initializes the PrivacyLedger.
+
+    Args:
+      population_size: An integer (may be variable) specifying the size of the
+        population.
+      selection_probability: A float (may be variable) specifying the
+        probability each record is included in a sample.
+      max_samples: The maximum number of samples. An exception is thrown if
+        more than this many samples are recorded.
+      max_queries: The maximum number of queries. An exception is thrown if
+        more than this many queries are recorded.
+    """
+    self._population_size = population_size
+    self._selection_probability = selection_probability
+
+    # The query buffer stores rows corresponding to GaussianSumQueryEntries.
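+    # Each row is [sample_count, l2_norm_bound, noise_stddev], stored as
+    # float32; _format_ledger casts the sample index back to int when the
+    # ledger is read out.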
+    self._query_buffer = tensor_buffer.TensorBuffer(
+        max_queries, [3], tf.float32, 'query')
+    self._sample_var = tf.Variable(
+        initial_value=tf.zeros([3]), trainable=False, name='sample')
+
+    # The sample buffer stores rows corresponding to SampleEntries.
+    self._sample_buffer = tensor_buffer.TensorBuffer(
+        max_samples, [3], tf.float32, 'sample')
+    self._sample_count = tf.Variable(
+        initial_value=0.0, trainable=False, name='sample_count')
+    self._query_count = tf.Variable(
+        initial_value=0.0, trainable=False, name='query_count')
+    self._cs = tf.contrib.framework.CriticalSection()
+
+  def record_sum_query(self, l2_norm_bound, noise_stddev):
+    """Records that a query was issued.
+
+    Args:
+      l2_norm_bound: The maximum l2 norm of the tensor group in the query.
+      noise_stddev: The standard deviation of the noise applied to the sum.
+
+    Returns:
+      An operation recording the sum query to the ledger.
+    """
+
+    def _do_record_query():
+      with tf.control_dependencies([
+          tf.assign(self._query_count, self._query_count + 1)]):
+        return self._query_buffer.append(
+            [self._sample_count, l2_norm_bound, noise_stddev])
+
+    return self._cs.execute(_do_record_query)
+
+  def finalize_sample(self):
+    """Finalizes sample and records sample ledger entry."""
+    with tf.control_dependencies([
+        tf.assign(
+            self._sample_var,
+            [self._population_size,
+             self._selection_probability,
+             self._query_count])]):
+      with tf.control_dependencies([
+          tf.assign(self._sample_count, self._sample_count + 1),
+          tf.assign(self._query_count, 0)]):
+        return self._sample_buffer.append(self._sample_var)
+
+  def _format_ledger(self, sample_array, query_array):
+    """Converts underlying representation into a list of SampleEntries."""
+    samples = []
+    query_pos = 0
+    sample_pos = 0
+    for sample in sample_array:
+      num_queries = int(sample[2])
+      queries = []
+      for _ in range(num_queries):
+        query = query_array[query_pos]
+        assert int(query[0]) == sample_pos
+        queries.append(GaussianSumQueryEntry(*query[1:]))
+        query_pos += 1
+      samples.append(SampleEntry(sample[0], sample[1], queries))
+      sample_pos += 1
+    return samples
+
+  def get_formatted_ledger(self, sess):
+    """Gets the formatted query ledger.
+
+    Args:
+      sess: The TensorFlow session in which the ledger was created.
+
+    Returns:
+      The query ledger as a list of SampleEntries.
+    """
+    sample_array = sess.run(self._sample_buffer.values)
+    query_array = sess.run(self._query_buffer.values)
+
+    return self._format_ledger(sample_array, query_array)
+
+  def get_formatted_ledger_eager(self):
+    """Gets the formatted query ledger.
+
+    Returns:
+      The query ledger as a list of SampleEntries.
+    """
+    sample_array = self._sample_buffer.values.numpy()
+    query_array = self._query_buffer.values.numpy()
+
+    return self._format_ledger(sample_array, query_array)
+
+
+class QueryWithLedger(dp_query.DPQuery):
+  """A class for DP queries that record events to a PrivacyLedger.
+
+  QueryWithLedger should be the top-level query in a structure of queries that
+  may include sum queries, nested queries, etc. It should simply wrap another
+  query and contain a reference to the ledger. Any contained queries (including
+  those contained in the leaves of a nested query) should also contain a
+  reference to the same ledger object.
+
+  For example usage, see privacy_ledger_test.py.
+  """
+
+  def __init__(self, query, ledger):
+    """Initializes the QueryWithLedger.
+
+    Args:
+      query: The query whose events should be recorded to the ledger. Any
+        subqueries (including those in the leaves of a nested query) should
+        also contain a reference to the same ledger given here.
+      ledger: A PrivacyLedger to which privacy events should be recorded.
+    """
+    self._query = query
+    self._ledger = ledger
+
+  def initial_global_state(self):
+    """See base class."""
+    return self._query.initial_global_state()
+
+  def derive_sample_params(self, global_state):
+    """See base class."""
+    return self._query.derive_sample_params(global_state)
+
+  def initial_sample_state(self, global_state, tensors):
+    """See base class."""
+    return self._query.initial_sample_state(global_state, tensors)
+
+  def accumulate_record(self, params, sample_state, record):
+    """See base class."""
+    return self._query.accumulate_record(params, sample_state, record)
+
+  def get_noised_result(self, sample_state, global_state):
+    """Ensures sample is recorded to the ledger and returns noised result."""
+    with tf.control_dependencies(nest.flatten(sample_state)):
+      with tf.control_dependencies([self._ledger.finalize_sample()]):
+        return self._query.get_noised_result(sample_state, global_state)
diff --git a/privacy/analysis/privacy_ledger_test.py b/privacy/analysis/privacy_ledger_test.py
new file mode 100644
index 0000000..48911a7
--- /dev/null
+++ b/privacy/analysis/privacy_ledger_test.py
@@ -0,0 +1,134 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for PrivacyLedger."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow_privacy.privacy.analysis import privacy_ledger
+from tensorflow_privacy.privacy.optimizers import gaussian_query
+from tensorflow_privacy.privacy.optimizers import nested_query
+from tensorflow_privacy.privacy.optimizers import test_utils
+
+tf.enable_eager_execution()
+
+
+class PrivacyLedgerTest(tf.test.TestCase):
+
+  def test_basic(self):
+    ledger = privacy_ledger.PrivacyLedger(10, 0.1, 50, 50)
+    ledger.record_sum_query(5.0, 1.0)
+    ledger.record_sum_query(2.0, 0.5)
+
+    ledger.finalize_sample()
+
+    expected_queries = [[5.0, 1.0], [2.0, 0.5]]
+    formatted = ledger.get_formatted_ledger_eager()
+
+    sample = formatted[0]
+    self.assertAllClose(sample.population_size, 10.0)
+    self.assertAllClose(sample.selection_probability, 0.1)
+    self.assertAllClose(sorted(sample.queries), sorted(expected_queries))
+
+  def test_sum_query(self):
+    record1 = tf.constant([2.0, 0.0])
+    record2 = tf.constant([-1.0, 1.0])
+
+    population_size = tf.Variable(0)
+    selection_probability = tf.Variable(0.0)
+    ledger = privacy_ledger.PrivacyLedger(
+        population_size, selection_probability, 50, 50)
+
+    query = gaussian_query.GaussianSumQuery(
+        l2_norm_clip=10.0, stddev=0.0, ledger=ledger)
+    query = privacy_ledger.QueryWithLedger(query, ledger)
+
+    # First sample.
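+    # Eager execution is enabled above, so the assignments below take effect
+    # immediately and the first sample is recorded with population_size=10
+    # and selection_probability=0.1.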
+    tf.assign(population_size, 10)
+    tf.assign(selection_probability, 0.1)
+    test_utils.run_query(query, [record1, record2])
+
+    expected_queries = [[10.0, 0.0]]
+    formatted = ledger.get_formatted_ledger_eager()
+    sample_1 = formatted[0]
+    self.assertAllClose(sample_1.population_size, 10.0)
+    self.assertAllClose(sample_1.selection_probability, 0.1)
+    self.assertAllClose(sample_1.queries, expected_queries)
+
+    # Second sample.
+    tf.assign(population_size, 20)
+    tf.assign(selection_probability, 0.2)
+    test_utils.run_query(query, [record1, record2])
+
+    formatted = ledger.get_formatted_ledger_eager()
+    sample_1, sample_2 = formatted
+    self.assertAllClose(sample_1.population_size, 10.0)
+    self.assertAllClose(sample_1.selection_probability, 0.1)
+    self.assertAllClose(sample_1.queries, expected_queries)
+
+    self.assertAllClose(sample_2.population_size, 20.0)
+    self.assertAllClose(sample_2.selection_probability, 0.2)
+    self.assertAllClose(sample_2.queries, expected_queries)
+
+  def test_nested_query(self):
+    population_size = tf.Variable(0)
+    selection_probability = tf.Variable(0.0)
+    ledger = privacy_ledger.PrivacyLedger(
+        population_size, selection_probability, 50, 50)
+
+    query1 = gaussian_query.GaussianAverageQuery(
+        l2_norm_clip=4.0, sum_stddev=2.0, denominator=5.0, ledger=ledger)
+    query2 = gaussian_query.GaussianAverageQuery(
+        l2_norm_clip=5.0, sum_stddev=1.0, denominator=5.0, ledger=ledger)
+
+    query = nested_query.NestedQuery([query1, query2])
+    query = privacy_ledger.QueryWithLedger(query, ledger)
+
+    record1 = [1.0, [12.0, 9.0]]
+    record2 = [5.0, [1.0, 2.0]]
+
+    # First sample.
+    tf.assign(population_size, 10)
+    tf.assign(selection_probability, 0.1)
+    test_utils.run_query(query, [record1, record2])
+
+    expected_queries = [[4.0, 2.0], [5.0, 1.0]]
+    formatted = ledger.get_formatted_ledger_eager()
+    sample_1 = formatted[0]
+    self.assertAllClose(sample_1.population_size, 10.0)
+    self.assertAllClose(sample_1.selection_probability, 0.1)
+    self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
+
+    # Second sample.
+    tf.assign(population_size, 20)
+    tf.assign(selection_probability, 0.2)
+    test_utils.run_query(query, [record1, record2])
+
+    formatted = ledger.get_formatted_ledger_eager()
+    sample_1, sample_2 = formatted
+    self.assertAllClose(sample_1.population_size, 10.0)
+    self.assertAllClose(sample_1.selection_probability, 0.1)
+    self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
+
+    self.assertAllClose(sample_2.population_size, 20.0)
+    self.assertAllClose(sample_2.selection_probability, 0.2)
+    self.assertAllClose(sorted(sample_2.queries), sorted(expected_queries))
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/privacy/analysis/tensor_buffer.py b/privacy/analysis/tensor_buffer.py
new file mode 100644
index 0000000..1b2341d
--- /dev/null
+++ b/privacy/analysis/tensor_buffer.py
@@ -0,0 +1,91 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A lightweight fixed-sized buffer for maintaining lists. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +class TensorBuffer(object): + """A lightweight fixed-sized buffer for maintaining lists. + + The TensorBuffer accumulates tensors of the given shape into a tensor (whose + rank is one more than that of the given shape) via calls to `append`. The + current value of the accumulated tensor can be extracted via the property + `values`. + """ + + def __init__(self, max_size, shape, dtype=tf.int32, name=None): + """Initializes the TensorBuffer. + + Args: + max_size: The maximum size. Attempts to append more than this many rows + will fail with an exception. + shape: The shape (as tuple or list) of the tensors to accumulate. + dtype: The type of the tensors. + name: A string name for the variable_scope used. + + Raises: + ValueError: If the shape is empty (specifies scalar shape). + """ + shape = list(shape) + self._rank = len(shape) + if not self._rank: + raise ValueError('Shape cannot be scalar.') + shape = [max_size] + shape + + with tf.variable_scope(name): + self._buffer = tf.Variable( + initial_value=tf.zeros(shape, dtype), + trainable=False, + name='buffer') + self._size = tf.Variable( + initial_value=0, + trainable=False, + name='size') + + def append(self, value): + """Appends a new tensor to the end of the buffer. + + Args: + value: The tensor to append. Must match the shape specified in the + initializer. + + Returns: + An op appending the new tensor to the end of the buffer. + """ + with tf.control_dependencies([ + tf.assert_less( + self._size, + tf.shape(self._buffer)[0], + message='Appending past end of TensorBuffer.'), + tf.assert_equal( + tf.shape(value), + tf.shape(self._buffer)[1:], + message='Appending value of inconsistent shape.')]): + with tf.control_dependencies( + [tf.assign(self._buffer[self._size, :], value)]): + return tf.assign_add(self._size, 1) + + @property + def values(self): + """Returns the accumulated tensor.""" + begin_value = tf.zeros([self._rank + 1], dtype=tf.int32) + value_size = tf.concat( + [[self._size], tf.constant(-1, tf.int32, [self._rank])], 0) + return tf.slice(self._buffer, begin_value, value_size) diff --git a/privacy/analysis/tensor_buffer_test.py b/privacy/analysis/tensor_buffer_test.py new file mode 100644 index 0000000..b2752c1 --- /dev/null +++ b/privacy/analysis/tensor_buffer_test.py @@ -0,0 +1,73 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for tensor_buffer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import tensor_buffer + +tf.enable_eager_execution() + + +class TensorBufferTest(tf.test.TestCase): + + def test_basic(self): + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + + value1 = [[1, 2, 3], [4, 5, 6]] + my_buffer.append(value1) + self.assertAllEqual(my_buffer.values.numpy(), [value1]) + + value2 = [[4, 5, 6], [7, 8, 9]] + my_buffer.append(value2) + self.assertAllEqual(my_buffer.values.numpy(), [value1, value2]) + + def test_fail_on_scalar(self): + with self.assertRaisesRegex(ValueError, 'Shape cannot be scalar.'): + tensor_buffer.TensorBuffer(1, ()) + + def test_fail_on_inconsistent_shape(self): + size, shape = 1, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, + 'Appending value of inconsistent shape.'): + my_buffer.append(tf.ones(shape=[3, 4], dtype=tf.int32)) + + def test_fail_on_overflow(self): + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + + # First two should succeed. + my_buffer.append(tf.ones(shape=shape, dtype=tf.int32)) + my_buffer.append(tf.ones(shape=shape, dtype=tf.int32)) + + # Third one should fail. + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, + 'Appending past end of TensorBuffer.'): + my_buffer.append(tf.ones(shape=shape, dtype=tf.int32)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 958354d..f2bc699 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -19,6 +19,7 @@ from __future__ import print_function import tensorflow as tf +from tensorflow_privacy.privacy.analysis import privacy_ledger from tensorflow_privacy.privacy.optimizers import gaussian_query @@ -121,6 +122,19 @@ def make_gaussian_optimizer_class(cls): **kwargs): dp_average_query = gaussian_query.GaussianAverageQuery( l2_norm_clip, l2_norm_clip * noise_multiplier, num_microbatches) + if 'population_size' in kwargs: + population_size = kwargs.pop('population_size') + max_queries = kwargs.pop('ledger_max_queries', 1e6) + max_samples = kwargs.pop('ledger_max_samples', 1e6) + selection_probability = num_microbatches / population_size + ledger = privacy_ledger.PrivacyLedger( + population_size, + selection_probability, + max_samples, + max_queries) + dp_average_query = privacy_ledger.QueryWithLedger( + dp_average_query, ledger) + super(DPGaussianOptimizerClass, self).__init__( dp_average_query, num_microbatches, diff --git a/privacy/optimizers/dp_optimizer_test.py b/privacy/optimizers/dp_optimizer_test.py index 91e12fb..2413b92 100644 --- a/privacy/optimizers/dp_optimizer_test.py +++ b/privacy/optimizers/dp_optimizer_test.py @@ -22,6 +22,7 @@ import mock import numpy as np import tensorflow as tf +from tensorflow_privacy.privacy.analysis import privacy_ledger from tensorflow_privacy.privacy.optimizers import dp_optimizer from tensorflow_privacy.privacy.optimizers import gaussian_query @@ -52,8 +53,12 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): var0 = tf.Variable([1.0, 2.0]) data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + ledger = privacy_ledger.PrivacyLedger( + 1e6, num_microbatches / 1e6, 50, 
       dp_average_query = gaussian_query.GaussianAverageQuery(
-          1.0e9, 0.0, num_microbatches)
+          1.0e9, 0.0, num_microbatches, ledger)
+      dp_average_query = privacy_ledger.QueryWithLedger(
+          dp_average_query, ledger)
 
       opt = cls(
           dp_average_query,
@@ -79,7 +84,10 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       var0 = tf.Variable([0.0, 0.0])
       data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
 
+      ledger = privacy_ledger.PrivacyLedger(1e6, 1 / 1e6, 50, 50)
       dp_average_query = gaussian_query.GaussianAverageQuery(1.0, 0.0, 1)
+      dp_average_query = privacy_ledger.QueryWithLedger(
+          dp_average_query, ledger)
 
       opt = cls(dp_average_query, num_microbatches=1, learning_rate=2.0)
 
@@ -101,7 +109,10 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       var0 = tf.Variable([0.0])
       data0 = tf.Variable([[0.0]])
 
+      ledger = privacy_ledger.PrivacyLedger(1e6, 1 / 1e6, 5000, 5000)
       dp_average_query = gaussian_query.GaussianAverageQuery(4.0, 8.0, 1)
+      dp_average_query = privacy_ledger.QueryWithLedger(
+          dp_average_query, ledger)
 
       opt = cls(dp_average_query, num_microbatches=1, learning_rate=2.0)
 
@@ -142,7 +153,10 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       vector_loss = tf.squared_difference(labels, preds)
       scalar_loss = tf.reduce_mean(vector_loss)
 
+      ledger = privacy_ledger.PrivacyLedger(1e6, 1 / 1e6, 500, 500)
       dp_average_query = gaussian_query.GaussianAverageQuery(1.0, 0.0, 1)
+      dp_average_query = privacy_ledger.QueryWithLedger(
+          dp_average_query, ledger)
       optimizer = dp_optimizer.DPGradientDescentOptimizer(
           dp_average_query,
           num_microbatches=1,
@@ -182,11 +196,17 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       var0 = tf.Variable([1.0, 2.0])
       data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
 
-      dp_average_query = (
-          gaussian_query.GaussianAverageQuery(1.0e9, 0.0, 4))
+      num_microbatches = 4
+
+      ledger = privacy_ledger.PrivacyLedger(
+          1e6, num_microbatches / 1e6, 50, 50)
+      dp_average_query = gaussian_query.GaussianAverageQuery(1.0e9, 0.0, 4)
+      dp_average_query = privacy_ledger.QueryWithLedger(
+          dp_average_query, ledger)
+
       opt = cls(
           dp_average_query,
-          num_microbatches=4,
+          num_microbatches=num_microbatches,
           learning_rate=2.0,
           unroll_microbatches=True)
diff --git a/privacy/optimizers/gaussian_query.py b/privacy/optimizers/gaussian_query.py
index 9f865a8..5d9dd8c 100644
--- a/privacy/optimizers/gaussian_query.py
+++ b/privacy/optimizers/gaussian_query.py
@@ -38,16 +38,18 @@ class GaussianSumQuery(dp_query.DPQuery):
   _GlobalState = collections.namedtuple(
       '_GlobalState', ['l2_norm_clip', 'stddev'])
 
-  def __init__(self, l2_norm_clip, stddev):
+  def __init__(self, l2_norm_clip, stddev, ledger=None):
     """Initializes the GaussianSumQuery.
 
     Args:
       l2_norm_clip: The clipping norm to apply to the global norm of each
         record.
       stddev: The stddev of the noise added to the sum.
+      ledger: The privacy ledger to which queries should be recorded.
     """
     self._l2_norm_clip = l2_norm_clip
     self._stddev = stddev
+    self._ledger = ledger
 
   def initial_global_state(self):
     """Returns the initial global state for the GaussianSumQuery."""
@@ -74,8 +76,12 @@ class GaussianSumQuery(dp_query.DPQuery):
     Returns:
       An initial sample state.
     """
-    del global_state  # unused.
-    return nest.map_structure(tf.zeros_like, tensors)
+    if self._ledger:
+      dependencies = [self._ledger.record_sum_query(*global_state)]
+    else:
+      dependencies = []
+    with tf.control_dependencies(dependencies):
+      return nest.map_structure(tf.zeros_like, tensors)
 
   def accumulate_record(self, params, sample_state, record):
     """Accumulates a single record into the sample state.
@@ -126,7 +132,11 @@ class GaussianAverageQuery(dp_query.DPQuery):
   _GlobalState = collections.namedtuple(
       '_GlobalState', ['sum_state', 'denominator'])
 
-  def __init__(self, l2_norm_clip, sum_stddev, denominator):
+  def __init__(self,
+               l2_norm_clip,
+               sum_stddev,
+               denominator,
+               ledger=None):
     """Initializes the GaussianAverageQuery.
 
     Args:
@@ -136,8 +146,9 @@ class GaussianAverageQuery(dp_query.DPQuery):
       l2_norm_clip: The clipping norm to apply to the global norm of each
         record.
       sum_stddev: The stddev of the noise added to the sum (before
        normalization).
       denominator: The normalization constant (applied after noise is added to
        the sum).
+      ledger: The privacy ledger to which queries should be recorded.
     """
-    self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev)
+    self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev, ledger)
     self._denominator = denominator
 
   def initial_global_state(self):
diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py
index 1a4a1dc..0162c78 100644
--- a/tutorials/mnist_dpsgd_tutorial.py
+++ b/tutorials/mnist_dpsgd_tutorial.py
@@ -76,7 +76,8 @@ def cnn_model_fn(features, labels, mode):
         l2_norm_clip=FLAGS.l2_norm_clip,
         noise_multiplier=FLAGS.noise_multiplier,
         num_microbatches=FLAGS.microbatches,
-        learning_rate=FLAGS.learning_rate)
+        learning_rate=FLAGS.learning_rate,
+        population_size=60000)
     opt_loss = vector_loss
   else:
     optimizer = tf.train.GradientDescentOptimizer(
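
Usage sketch (illustrative, not part of the patch): the snippet below shows how the pieces introduced here fit together in eager mode, mirroring privacy_ledger_test.py. It assumes the repository layout at this revision; test_utils.run_query is the test helper imported by the tests above, and the clip/stddev values are arbitrary.

import tensorflow as tf

from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.optimizers import gaussian_query
from tensorflow_privacy.privacy.optimizers import test_utils

tf.enable_eager_execution()

# Population of 10 records, each sampled with probability 0.1; the underlying
# buffers hold at most 50 samples and 50 queries.
ledger = privacy_ledger.PrivacyLedger(10, 0.1, 50, 50)

# The ledger goes both to the sum query (so every sum query is recorded) and
# to QueryWithLedger (so every sample is finalized).
query = gaussian_query.GaussianSumQuery(
    l2_norm_clip=10.0, stddev=1.0, ledger=ledger)
query = privacy_ledger.QueryWithLedger(query, ledger)

test_utils.run_query(query, [tf.constant([2.0, 0.0]), tf.constant([-1.0, 1.0])])

# The formatted ledger can then be handed to a privacy accountant.
for sample in ledger.get_formatted_ledger_eager():
  print(sample.population_size, sample.selection_probability, sample.queries)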