1d1a6e087a
1. Split DPQuery.accumulate_record function into preprocess_record and accumulate_preprocessed_record. 2. Add merge_sample_state function. 3. Add default implementations for some functions in DPQuery, and add base class SumAggregationDPQuery that implements some more. Only get_noised_result is still abstract. 4. Enforce that all states and parameters used as inputs and outputs to DPQuery functions are nested structures of tensors by replacing numbers with constants and Nones with empty tuples. PiperOrigin-RevId: 247975791
262 lines
9.1 KiB
Python
262 lines
9.1 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PrivacyLedger class for keeping a record of private queries."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
|
|
from distutils.version import LooseVersion
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from privacy.analysis import tensor_buffer
|
|
from privacy.dp_query import dp_query
|
|
|
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
|
nest = tf.contrib.framework.nest
|
|
else:
|
|
nest = tf.nest
|
|
|
|
SampleEntry = collections.namedtuple( # pylint: disable=invalid-name
|
|
'SampleEntry', ['population_size', 'selection_probability', 'queries'])
|
|
|
|
GaussianSumQueryEntry = collections.namedtuple( # pylint: disable=invalid-name
|
|
'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev'])
|
|
|
|
|
|
def format_ledger(sample_array, query_array):
|
|
"""Converts array representation into a list of SampleEntries."""
|
|
samples = []
|
|
query_pos = 0
|
|
sample_pos = 0
|
|
for sample in sample_array:
|
|
population_size, selection_probability, num_queries = sample
|
|
queries = []
|
|
for _ in range(int(num_queries)):
|
|
query = query_array[query_pos]
|
|
assert int(query[0]) == sample_pos
|
|
queries.append(GaussianSumQueryEntry(*query[1:]))
|
|
query_pos += 1
|
|
samples.append(SampleEntry(population_size, selection_probability, queries))
|
|
sample_pos += 1
|
|
return samples
|
|
|
|
|
|
class PrivacyLedger(object):
|
|
"""Class for keeping a record of private queries.
|
|
|
|
The PrivacyLedger keeps a record of all queries executed over a given dataset
|
|
for the purpose of computing privacy guarantees.
|
|
"""
|
|
|
|
def __init__(self,
|
|
population_size,
|
|
selection_probability=None,
|
|
max_samples=None,
|
|
max_queries=None):
|
|
"""Initialize the PrivacyLedger.
|
|
|
|
Args:
|
|
population_size: An integer (may be variable) specifying the size of the
|
|
population, i.e. size of the training data used in each epoch.
|
|
selection_probability: A float (may be variable) specifying the
|
|
probability each record is included in a sample.
|
|
max_samples: The maximum number of samples. An exception is thrown if more
|
|
than this many samples are recorded.
|
|
max_queries: The maximum number of queries. An exception is thrown if more
|
|
than this many queries are recorded.
|
|
"""
|
|
self._population_size = population_size
|
|
self._selection_probability = selection_probability
|
|
if max_samples is None:
|
|
max_samples = 1000 * population_size
|
|
if max_queries is None:
|
|
max_queries = 1000 * population_size
|
|
|
|
# The query buffer stores rows corresponding to GaussianSumQueryEntries.
|
|
self._query_buffer = tensor_buffer.TensorBuffer(max_queries, [3],
|
|
tf.float32, 'query')
|
|
self._sample_var = tf.Variable(
|
|
initial_value=tf.zeros([3]), trainable=False, name='sample')
|
|
|
|
# The sample buffer stores rows corresponding to SampleEntries.
|
|
self._sample_buffer = tensor_buffer.TensorBuffer(max_samples, [3],
|
|
tf.float32, 'sample')
|
|
self._sample_count = tf.Variable(
|
|
initial_value=0.0, trainable=False, name='sample_count')
|
|
self._query_count = tf.Variable(
|
|
initial_value=0.0, trainable=False, name='query_count')
|
|
try:
|
|
# Newer versions of TF
|
|
self._cs = tf.CriticalSection()
|
|
except AttributeError:
|
|
# Older versions of TF
|
|
self._cs = tf.contrib.framework.CriticalSection()
|
|
|
|
def record_sum_query(self, l2_norm_bound, noise_stddev):
|
|
"""Records that a query was issued.
|
|
|
|
Args:
|
|
l2_norm_bound: The maximum l2 norm of the tensor group in the query.
|
|
noise_stddev: The standard deviation of the noise applied to the sum.
|
|
|
|
Returns:
|
|
An operation recording the sum query to the ledger.
|
|
"""
|
|
|
|
def _do_record_query():
|
|
with tf.control_dependencies(
|
|
[tf.assign(self._query_count, self._query_count + 1)]):
|
|
return self._query_buffer.append(
|
|
[self._sample_count, l2_norm_bound, noise_stddev])
|
|
|
|
return self._cs.execute(_do_record_query)
|
|
|
|
def finalize_sample(self):
|
|
"""Finalizes sample and records sample ledger entry."""
|
|
with tf.control_dependencies([
|
|
tf.assign(self._sample_var, [
|
|
self._population_size, self._selection_probability,
|
|
self._query_count
|
|
])
|
|
]):
|
|
with tf.control_dependencies([
|
|
tf.assign(self._sample_count, self._sample_count + 1),
|
|
tf.assign(self._query_count, 0)
|
|
]):
|
|
return self._sample_buffer.append(self._sample_var)
|
|
|
|
def get_unformatted_ledger(self):
|
|
return self._sample_buffer.values, self._query_buffer.values
|
|
|
|
def get_formatted_ledger(self, sess):
|
|
"""Gets the formatted query ledger.
|
|
|
|
Args:
|
|
sess: The tensorflow session in which the ledger was created.
|
|
|
|
Returns:
|
|
The query ledger as a list of SampleEntries.
|
|
"""
|
|
sample_array = sess.run(self._sample_buffer.values)
|
|
query_array = sess.run(self._query_buffer.values)
|
|
|
|
return format_ledger(sample_array, query_array)
|
|
|
|
def get_formatted_ledger_eager(self):
|
|
"""Gets the formatted query ledger.
|
|
|
|
Returns:
|
|
The query ledger as a list of SampleEntries.
|
|
"""
|
|
sample_array = self._sample_buffer.values.numpy()
|
|
query_array = self._query_buffer.values.numpy()
|
|
|
|
return format_ledger(sample_array, query_array)
|
|
|
|
def set_sample_size(self, batch_size):
|
|
self._selection_probability = tf.cast(batch_size,
|
|
tf.float32) / self._population_size
|
|
|
|
|
|
class DummyLedger(object):
|
|
"""A ledger that records nothing.
|
|
|
|
This ledger may be passed in place of a normal PrivacyLedger in case privacy
|
|
accounting is to be handled externally.
|
|
"""
|
|
|
|
def record_sum_query(self, l2_norm_bound, noise_stddev):
|
|
del l2_norm_bound
|
|
del noise_stddev
|
|
return tf.no_op()
|
|
|
|
def finalize_sample(self):
|
|
return tf.no_op()
|
|
|
|
def get_unformatted_ledger(self):
|
|
empty_array = tf.zeros(shape=[0, 3])
|
|
return empty_array, empty_array
|
|
|
|
def get_formatted_ledger(self, sess):
|
|
del sess
|
|
empty_array = np.zeros(shape=[0, 3])
|
|
return empty_array, empty_array
|
|
|
|
def get_formatted_ledger_eager(self):
|
|
empty_array = np.zeros(shape=[0, 3])
|
|
return empty_array, empty_array
|
|
|
|
|
|
class QueryWithLedger(dp_query.DPQuery):
|
|
"""A class for DP queries that record events to a PrivacyLedger.
|
|
|
|
QueryWithLedger should be the top-level query in a structure of queries that
|
|
may include sum queries, nested queries, etc. It should simply wrap another
|
|
query and contain a reference to the ledger. Any contained queries (including
|
|
those contained in the leaves of a nested query) should also contain a
|
|
reference to the same ledger object.
|
|
|
|
For example usage, see privacy_ledger_test.py.
|
|
"""
|
|
|
|
def __init__(self, query, ledger):
|
|
"""Initializes the QueryWithLedger.
|
|
|
|
Args:
|
|
query: The query whose events should be recorded to the ledger. Any
|
|
subqueries (including those in the leaves of a nested query) should also
|
|
contain a reference to the same ledger given here.
|
|
ledger: A PrivacyLedger to which privacy events should be recorded.
|
|
"""
|
|
self._query = query
|
|
self._ledger = ledger
|
|
|
|
def initial_global_state(self):
|
|
"""See base class."""
|
|
return self._query.initial_global_state()
|
|
|
|
def derive_sample_params(self, global_state):
|
|
"""See base class."""
|
|
return self._query.derive_sample_params(global_state)
|
|
|
|
def initial_sample_state(self, global_state, template):
|
|
"""See base class."""
|
|
return self._query.initial_sample_state(global_state, template)
|
|
|
|
def preprocess_record(self, params, record):
|
|
"""See base class."""
|
|
return self._query.preprocess_record(params, record)
|
|
|
|
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
|
"""See base class."""
|
|
return self._query.accumulate_preprocessed_record(
|
|
sample_state, preprocessed_record)
|
|
|
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
|
"""See base class."""
|
|
return self._query.merge_sample_states(sample_state_1, sample_state_2)
|
|
|
|
def get_noised_result(self, sample_state, global_state):
|
|
"""Ensures sample is recorded to the ledger and returns noised result."""
|
|
with tf.control_dependencies(nest.flatten(sample_state)):
|
|
with tf.control_dependencies([self._ledger.finalize_sample()]):
|
|
return self._query.get_noised_result(sample_state, global_state)
|
|
|
|
def set_denominator(self, num_microbatches, microbatch_size=1):
|
|
self._query.set_denominator(num_microbatches)
|
|
self._ledger.set_sample_size(num_microbatches * microbatch_size)
|