Add privacy ledger.

The privacy ledger keeps a record of all sampling and query events for analysis post hoc by the privacy accountant.

PiperOrigin-RevId: 233094012
This commit is contained in:
A. Unique TensorFlower 2019-02-08 11:21:20 -08:00
parent 36d9959c19
commit 4d0ab48c35
8 changed files with 553 additions and 10 deletions

View file

@ -0,0 +1,199 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PrivacyLedger class for keeping a record of private queries.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import tensor_buffer
from tensorflow_privacy.privacy.optimizers import dp_query
nest = tf.contrib.framework.nest
SampleEntry = collections.namedtuple( # pylint: disable=invalid-name
'SampleEntry', ['population_size', 'selection_probability', 'queries'])
GaussianSumQueryEntry = collections.namedtuple( # pylint: disable=invalid-name
'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev'])
class PrivacyLedger(object):
"""Class for keeping a record of private queries.
The PrivacyLedger keeps a record of all queries executed over a given dataset
for the purpose of computing privacy guarantees.
"""
def __init__(
self,
population_size,
selection_probability,
max_samples,
max_queries):
"""Initialize the PrivacyLedger.
Args:
population_size: An integer (may be variable) specifying the size of the
population.
selection_probability: A float (may be variable) specifying the
probability each record is included in a sample.
max_samples: The maximum number of samples. An exception is thrown if
more than this many samples are recorded.
max_queries: The maximum number of queries. An exception is thrown if
more than this many queries are recorded.
"""
self._population_size = population_size
self._selection_probability = selection_probability
# The query buffer stores rows corresponding to GaussianSumQueryEntries.
self._query_buffer = tensor_buffer.TensorBuffer(
max_queries, [3], tf.float32, 'query')
self._sample_var = tf.Variable(
initial_value=tf.zeros([3]), trainable=False, name='sample')
# The sample buffer stores rows corresponding to SampleEntries.
self._sample_buffer = tensor_buffer.TensorBuffer(
max_samples, [3], tf.float32, 'sample')
self._sample_count = tf.Variable(
initial_value=0.0, trainable=False, name='sample_count')
self._query_count = tf.Variable(
initial_value=0.0, trainable=False, name='query_count')
self._cs = tf.contrib.framework.CriticalSection()
def record_sum_query(self, l2_norm_bound, noise_stddev):
"""Records that a query was issued.
Args:
l2_norm_bound: The maximum l2 norm of the tensor group in the query.
noise_stddev: The standard deviation of the noise applied to the sum.
Returns:
An operation recording the sum query to the ledger.
"""
def _do_record_query():
with tf.control_dependencies([
tf.assign(self._query_count, self._query_count + 1)]):
return self._query_buffer.append(
[self._sample_count, l2_norm_bound, noise_stddev])
return self._cs.execute(_do_record_query)
def finalize_sample(self):
"""Finalizes sample and records sample ledger entry."""
with tf.control_dependencies([
tf.assign(
self._sample_var,
[self._population_size,
self._selection_probability,
self._query_count])]):
with tf.control_dependencies([
tf.assign(self._sample_count, self._sample_count + 1),
tf.assign(self._query_count, 0)]):
return self._sample_buffer.append(self._sample_var)
def _format_ledger(self, sample_array, query_array):
"""Converts underlying representation into a list of SampleEntries."""
samples = []
query_pos = 0
sample_pos = 0
for sample in sample_array:
num_queries = int(sample[2])
queries = []
for _ in range(num_queries):
query = query_array[query_pos]
assert int(query[0]) == sample_pos
queries.append(GaussianSumQueryEntry(*query[1:]))
query_pos += 1
samples.append(SampleEntry(sample[0], sample[1], queries))
sample_pos += 1
return samples
def get_formatted_ledger(self, sess):
"""Gets the formatted query ledger.
Args:
sess: The tensorflow session in which the ledger was created.
Returns:
The query ledger as a list of SampleEntries.
"""
sample_array = sess.run(self._sample_buffer.values)
query_array = sess.run(self._query_buffer.values)
return self._format_ledger(sample_array, query_array)
def get_formatted_ledger_eager(self):
"""Gets the formatted query ledger.
Returns:
The query ledger as a list of SampleEntries.
"""
sample_array = self._sample_buffer.values.numpy()
query_array = self._query_buffer.values.numpy()
return self._format_ledger(sample_array, query_array)
class QueryWithLedger(dp_query.DPQuery):
"""A class for DP queries that record events to a PrivacyLedger.
QueryWithLedger should be the top-level query in a structure of queries that
may include sum queries, nested queries, etc. It should simply wrap another
query and contain a reference to the ledger. Any contained queries (including
those contained in the leaves of a nested query) should also contain a
reference to the same ledger object.
For example usage, see privacy_ledger_test.py.
"""
def __init__(self, query, ledger):
"""Initializes the QueryWithLedger.
Args:
query: The query whose events should be recorded to the ledger. Any
subqueries (including those in the leaves of a nested query) should
also contain a reference to the same ledger given here.
ledger: A PrivacyLedger to which privacy events should be recorded.
"""
self._query = query
self._ledger = ledger
def initial_global_state(self):
"""See base class."""
return self._query.initial_global_state()
def derive_sample_params(self, global_state):
"""See base class."""
return self._query.derive_sample_params(global_state)
def initial_sample_state(self, global_state, tensors):
"""See base class."""
return self._query.initial_sample_state(global_state, tensors)
def accumulate_record(self, params, sample_state, record):
"""See base class."""
return self._query.accumulate_record(params, sample_state, record)
def get_noised_result(self, sample_state, global_state):
"""Ensures sample is recorded to the ledger and returns noised result."""
with tf.control_dependencies(nest.flatten(sample_state)):
with tf.control_dependencies([self._ledger.finalize_sample()]):
return self._query.get_noised_result(sample_state, global_state)

View file

@ -0,0 +1,134 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PrivacyLedger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.optimizers import gaussian_query
from tensorflow_privacy.privacy.optimizers import nested_query
from tensorflow_privacy.privacy.optimizers import test_utils
tf.enable_eager_execution()
class PrivacyLedgerTest(tf.test.TestCase):
def test_basic(self):
ledger = privacy_ledger.PrivacyLedger(10, 0.1, 50, 50)
ledger.record_sum_query(5.0, 1.0)
ledger.record_sum_query(2.0, 0.5)
ledger.finalize_sample()
expected_queries = [[5.0, 1.0], [2.0, 0.5]]
formatted = ledger.get_formatted_ledger_eager()
sample = formatted[0]
self.assertAllClose(sample.population_size, 10.0)
self.assertAllClose(sample.selection_probability, 0.1)
self.assertAllClose(sorted(sample.queries), sorted(expected_queries))
def test_sum_query(self):
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
population_size = tf.Variable(0)
selection_probability = tf.Variable(0.0)
ledger = privacy_ledger.PrivacyLedger(
population_size, selection_probability, 50, 50)
query = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0, ledger=ledger)
query = privacy_ledger.QueryWithLedger(query, ledger)
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[10.0, 0.0]]
formatted = ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sample_2.queries, expected_queries)
def test_nested_query(self):
population_size = tf.Variable(0)
selection_probability = tf.Variable(0.0)
ledger = privacy_ledger.PrivacyLedger(
population_size, selection_probability, 50, 50)
query1 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=4.0, sum_stddev=2.0, denominator=5.0, ledger=ledger)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=1.0, denominator=5.0, ledger=ledger)
query = nested_query.NestedQuery([query1, query2])
query = privacy_ledger.QueryWithLedger(query, ledger)
record1 = [1.0, [12.0, 9.0]]
record2 = [5.0, [1.0, 2.0]]
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[4.0, 2.0], [5.0, 1.0]]
formatted = ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sorted(sample_2.queries), sorted(expected_queries))
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,91 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A lightweight fixed-sized buffer for maintaining lists.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class TensorBuffer(object):
"""A lightweight fixed-sized buffer for maintaining lists.
The TensorBuffer accumulates tensors of the given shape into a tensor (whose
rank is one more than that of the given shape) via calls to `append`. The
current value of the accumulated tensor can be extracted via the property
`values`.
"""
def __init__(self, max_size, shape, dtype=tf.int32, name=None):
"""Initializes the TensorBuffer.
Args:
max_size: The maximum size. Attempts to append more than this many rows
will fail with an exception.
shape: The shape (as tuple or list) of the tensors to accumulate.
dtype: The type of the tensors.
name: A string name for the variable_scope used.
Raises:
ValueError: If the shape is empty (specifies scalar shape).
"""
shape = list(shape)
self._rank = len(shape)
if not self._rank:
raise ValueError('Shape cannot be scalar.')
shape = [max_size] + shape
with tf.variable_scope(name):
self._buffer = tf.Variable(
initial_value=tf.zeros(shape, dtype),
trainable=False,
name='buffer')
self._size = tf.Variable(
initial_value=0,
trainable=False,
name='size')
def append(self, value):
"""Appends a new tensor to the end of the buffer.
Args:
value: The tensor to append. Must match the shape specified in the
initializer.
Returns:
An op appending the new tensor to the end of the buffer.
"""
with tf.control_dependencies([
tf.assert_less(
self._size,
tf.shape(self._buffer)[0],
message='Appending past end of TensorBuffer.'),
tf.assert_equal(
tf.shape(value),
tf.shape(self._buffer)[1:],
message='Appending value of inconsistent shape.')]):
with tf.control_dependencies(
[tf.assign(self._buffer[self._size, :], value)]):
return tf.assign_add(self._size, 1)
@property
def values(self):
"""Returns the accumulated tensor."""
begin_value = tf.zeros([self._rank + 1], dtype=tf.int32)
value_size = tf.concat(
[[self._size], tf.constant(-1, tf.int32, [self._rank])], 0)
return tf.slice(self._buffer, begin_value, value_size)

View file

@ -0,0 +1,73 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor_buffer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import tensor_buffer
tf.enable_eager_execution()
class TensorBufferTest(tf.test.TestCase):
def test_basic(self):
size, shape = 2, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
value1 = [[1, 2, 3], [4, 5, 6]]
my_buffer.append(value1)
self.assertAllEqual(my_buffer.values.numpy(), [value1])
value2 = [[4, 5, 6], [7, 8, 9]]
my_buffer.append(value2)
self.assertAllEqual(my_buffer.values.numpy(), [value1, value2])
def test_fail_on_scalar(self):
with self.assertRaisesRegex(ValueError, 'Shape cannot be scalar.'):
tensor_buffer.TensorBuffer(1, ())
def test_fail_on_inconsistent_shape(self):
size, shape = 1, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
'Appending value of inconsistent shape.'):
my_buffer.append(tf.ones(shape=[3, 4], dtype=tf.int32))
def test_fail_on_overflow(self):
size, shape = 2, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
# First two should succeed.
my_buffer.append(tf.ones(shape=shape, dtype=tf.int32))
my_buffer.append(tf.ones(shape=shape, dtype=tf.int32))
# Third one should fail.
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
'Appending past end of TensorBuffer.'):
my_buffer.append(tf.ones(shape=shape, dtype=tf.int32))
if __name__ == '__main__':
tf.test.main()

View file

@ -19,6 +19,7 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.optimizers import gaussian_query
@ -121,6 +122,19 @@ def make_gaussian_optimizer_class(cls):
**kwargs):
dp_average_query = gaussian_query.GaussianAverageQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier, num_microbatches)
if 'population_size' in kwargs:
population_size = kwargs.pop('population_size')
max_queries = kwargs.pop('ledger_max_queries', 1e6)
max_samples = kwargs.pop('ledger_max_samples', 1e6)
selection_probability = num_microbatches / population_size
ledger = privacy_ledger.PrivacyLedger(
population_size,
selection_probability,
max_samples,
max_queries)
dp_average_query = privacy_ledger.QueryWithLedger(
dp_average_query, ledger)
super(DPGaussianOptimizerClass, self).__init__(
dp_average_query,
num_microbatches,

View file

@ -22,6 +22,7 @@ import mock
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.optimizers import dp_optimizer
from tensorflow_privacy.privacy.optimizers import gaussian_query
@ -52,8 +53,12 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
ledger = privacy_ledger.PrivacyLedger(
1e6, num_microbatches / 1e6, 50, 50)
dp_average_query = gaussian_query.GaussianAverageQuery(
1.0e9, 0.0, num_microbatches)
1.0e9, 0.0, num_microbatches, ledger)
dp_average_query = privacy_ledger.QueryWithLedger(
dp_average_query, ledger)
opt = cls(
dp_average_query,
@ -79,7 +84,10 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
ledger = privacy_ledger.PrivacyLedger(1e6, 1 / 1e6, 50, 50)
dp_average_query = gaussian_query.GaussianAverageQuery(1.0, 0.0, 1)
dp_average_query = privacy_ledger.QueryWithLedger(
dp_average_query, ledger)
opt = cls(dp_average_query, num_microbatches=1, learning_rate=2.0)
@ -101,7 +109,10 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
ledger = privacy_ledger.PrivacyLedger(1e6, 1 / 1e6, 5000, 5000)
dp_average_query = gaussian_query.GaussianAverageQuery(4.0, 8.0, 1)
dp_average_query = privacy_ledger.QueryWithLedger(
dp_average_query, ledger)
opt = cls(dp_average_query, num_microbatches=1, learning_rate=2.0)
@ -142,7 +153,10 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
vector_loss = tf.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(vector_loss)
ledger = privacy_ledger.PrivacyLedger(1e6, 1 / 1e6, 500, 500)
dp_average_query = gaussian_query.GaussianAverageQuery(1.0, 0.0, 1)
dp_average_query = privacy_ledger.QueryWithLedger(
dp_average_query, ledger)
optimizer = dp_optimizer.DPGradientDescentOptimizer(
dp_average_query,
num_microbatches=1,
@ -182,11 +196,17 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
dp_average_query = (
gaussian_query.GaussianAverageQuery(1.0e9, 0.0, 4))
num_microbatches = 4
ledger = privacy_ledger.PrivacyLedger(
1e6, num_microbatches / 1e6, 50, 50)
dp_average_query = gaussian_query.GaussianAverageQuery(1.0e9, 0.0, 4)
dp_average_query = privacy_ledger.QueryWithLedger(
dp_average_query, ledger)
opt = cls(
dp_average_query,
num_microbatches=4,
num_microbatches=num_microbatches,
learning_rate=2.0,
unroll_microbatches=True)

View file

@ -38,16 +38,18 @@ class GaussianSumQuery(dp_query.DPQuery):
_GlobalState = collections.namedtuple(
'_GlobalState', ['l2_norm_clip', 'stddev'])
def __init__(self, l2_norm_clip, stddev):
def __init__(self, l2_norm_clip, stddev, ledger=None):
"""Initializes the GaussianSumQuery.
Args:
l2_norm_clip: The clipping norm to apply to the global norm of each
record.
stddev: The stddev of the noise added to the sum.
ledger: The privacy ledger to which queries should be recorded.
"""
self._l2_norm_clip = l2_norm_clip
self._stddev = stddev
self._ledger = ledger
def initial_global_state(self):
"""Returns the initial global state for the GaussianSumQuery."""
@ -74,8 +76,12 @@ class GaussianSumQuery(dp_query.DPQuery):
Returns: An initial sample state.
"""
del global_state # unused.
return nest.map_structure(tf.zeros_like, tensors)
if self._ledger:
dependencies = [self._ledger.record_sum_query(*global_state)]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return nest.map_structure(tf.zeros_like, tensors)
def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
@ -126,7 +132,11 @@ class GaussianAverageQuery(dp_query.DPQuery):
_GlobalState = collections.namedtuple(
'_GlobalState', ['sum_state', 'denominator'])
def __init__(self, l2_norm_clip, sum_stddev, denominator):
def __init__(self,
l2_norm_clip,
sum_stddev,
denominator,
ledger=None):
"""Initializes the GaussianAverageQuery.
Args:
@ -136,8 +146,9 @@ class GaussianAverageQuery(dp_query.DPQuery):
normalization).
denominator: The normalization constant (applied after noise is added to
the sum).
ledger: The privacy ledger to which queries should be recorded.
"""
self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev)
self._numerator = GaussianSumQuery(l2_norm_clip, sum_stddev, ledger)
self._denominator = denominator
def initial_global_state(self):

View file

@ -76,7 +76,8 @@ def cnn_model_fn(features, labels, mode):
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
learning_rate=FLAGS.learning_rate)
learning_rate=FLAGS.learning_rate,
population_size=60000)
opt_loss = vector_loss
else:
optimizer = tf.train.GradientDescentOptimizer(