forked from 626_privacy/tensorflow_privacy
Use PrivacyLedger for privacy accounting.
Prior to this change the PrivacyLedger is running to keep a log of private queries, but the ledger is not actually used to compute the (epsilon, delta) guarantees. This CL adds a function to compute the RDP directly from the ledger. Note I did verify that the tutorial builds and runs with the changes and for the first few iterations prints the same epsilon values as before the change. PiperOrigin-RevId: 241063532
This commit is contained in:
parent
8507094f2b
commit
9106a04e2c
6 changed files with 157 additions and 55 deletions
|
@ -22,6 +22,7 @@ from __future__ import print_function
|
|||
import collections
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from privacy.analysis import tensor_buffer
|
||||
|
@ -39,6 +40,24 @@ GaussianSumQueryEntry = collections.namedtuple( # pylint: disable=invalid-name
|
|||
'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev'])
|
||||
|
||||
|
||||
def format_ledger(sample_array, query_array):
|
||||
"""Converts array representation into a list of SampleEntries."""
|
||||
samples = []
|
||||
query_pos = 0
|
||||
sample_pos = 0
|
||||
for sample in sample_array:
|
||||
population_size, selection_probability, num_queries = sample
|
||||
queries = []
|
||||
for _ in range(int(num_queries)):
|
||||
query = query_array[query_pos]
|
||||
assert int(query[0]) == sample_pos
|
||||
queries.append(GaussianSumQueryEntry(*query[1:]))
|
||||
query_pos += 1
|
||||
samples.append(SampleEntry(population_size, selection_probability, queries))
|
||||
sample_pos += 1
|
||||
return samples
|
||||
|
||||
|
||||
class PrivacyLedger(object):
|
||||
"""Class for keeping a record of private queries.
|
||||
|
||||
|
@ -118,22 +137,8 @@ class PrivacyLedger(object):
|
|||
tf.assign(self._query_count, 0)]):
|
||||
return self._sample_buffer.append(self._sample_var)
|
||||
|
||||
def _format_ledger(self, sample_array, query_array):
|
||||
"""Converts underlying representation into a list of SampleEntries."""
|
||||
samples = []
|
||||
query_pos = 0
|
||||
sample_pos = 0
|
||||
for sample in sample_array:
|
||||
num_queries = int(sample[2])
|
||||
queries = []
|
||||
for _ in range(num_queries):
|
||||
query = query_array[query_pos]
|
||||
assert int(query[0]) == sample_pos
|
||||
queries.append(GaussianSumQueryEntry(*query[1:]))
|
||||
query_pos += 1
|
||||
samples.append(SampleEntry(sample[0], sample[1], queries))
|
||||
sample_pos += 1
|
||||
return samples
|
||||
def get_unformatted_ledger(self):
|
||||
return self._sample_buffer.values, self._query_buffer.values
|
||||
|
||||
def get_formatted_ledger(self, sess):
|
||||
"""Gets the formatted query ledger.
|
||||
|
@ -147,7 +152,7 @@ class PrivacyLedger(object):
|
|||
sample_array = sess.run(self._sample_buffer.values)
|
||||
query_array = sess.run(self._query_buffer.values)
|
||||
|
||||
return self._format_ledger(sample_array, query_array)
|
||||
return format_ledger(sample_array, query_array)
|
||||
|
||||
def get_formatted_ledger_eager(self):
|
||||
"""Gets the formatted query ledger.
|
||||
|
@ -158,7 +163,36 @@ class PrivacyLedger(object):
|
|||
sample_array = self._sample_buffer.values.numpy()
|
||||
query_array = self._query_buffer.values.numpy()
|
||||
|
||||
return self._format_ledger(sample_array, query_array)
|
||||
return format_ledger(sample_array, query_array)
|
||||
|
||||
|
||||
class DummyLedger(object):
|
||||
"""A ledger that records nothing.
|
||||
|
||||
This ledger may be passed in place of a normal PrivacyLedger in case privacy
|
||||
accounting is to be handled externally.
|
||||
"""
|
||||
|
||||
def record_sum_query(self, l2_norm_bound, noise_stddev):
|
||||
del l2_norm_bound
|
||||
del noise_stddev
|
||||
return tf.no_op()
|
||||
|
||||
def finalize_sample(self):
|
||||
return tf.no_op()
|
||||
|
||||
def get_unformatted_ledger(self):
|
||||
empty_array = tf.zeros(shape=[0, 3])
|
||||
return empty_array, empty_array
|
||||
|
||||
def get_formatted_ledger(self, sess):
|
||||
del sess
|
||||
empty_array = np.zeros(shape=[0, 3])
|
||||
return empty_array, empty_array
|
||||
|
||||
def get_formatted_ledger_eager(self):
|
||||
empty_array = np.zeros(shape=[0, 3])
|
||||
return empty_array, empty_array
|
||||
|
||||
|
||||
class QueryWithLedger(dp_query.DPQuery):
|
||||
|
|
|
@ -295,3 +295,42 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
|
|||
else:
|
||||
eps, opt_order = _compute_eps(orders, rdp, target_delta)
|
||||
return eps, target_delta, opt_order
|
||||
|
||||
|
||||
def compute_rdp_from_ledger(ledger, orders):
|
||||
"""Compute RDP of Sampled Gaussian Mechanism from ledger.
|
||||
|
||||
Args:
|
||||
ledger: A formatted privacy ledger.
|
||||
orders: An array (or a scalar) of RDP orders.
|
||||
|
||||
Returns:
|
||||
RDP at all orders, can be np.inf.
|
||||
"""
|
||||
total_rdp = 0
|
||||
for sample in ledger:
|
||||
# Compute equivalent z from l2_clip_bounds and noise stddevs in sample.
|
||||
# See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula.
|
||||
effective_z = sum([
|
||||
(q.noise_stddev / q.l2_norm_bound)**-2 for q in sample.queries])**-0.5
|
||||
total_rdp += compute_rdp(
|
||||
sample.selection_probability, effective_z, 1, orders)
|
||||
return total_rdp
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from mpmath import npdf
|
|||
from mpmath import quad
|
||||
import numpy as np
|
||||
|
||||
from privacy.analysis import privacy_ledger
|
||||
from privacy.analysis import rdp_accountant
|
||||
|
||||
|
||||
|
@ -154,6 +155,23 @@ class TestGaussianMoments(parameterized.TestCase):
|
|||
self.assertAlmostEqual(eps, 8.509656, places=5)
|
||||
self.assertEqual(opt_order, 2.5)
|
||||
|
||||
def test_compute_rdp_from_ledger(self):
|
||||
orders = range(2, 33)
|
||||
q = 0.1
|
||||
n = 1000
|
||||
l2_norm_clip = 3.14159
|
||||
noise_stddev = 2.71828
|
||||
steps = 3
|
||||
|
||||
query_entry = privacy_ledger.GaussianSumQueryEntry(
|
||||
l2_norm_clip, noise_stddev)
|
||||
ledger = [privacy_ledger.SampleEntry(n, q, [query_entry])] * steps
|
||||
|
||||
z = noise_stddev / l2_norm_clip
|
||||
rdp = rdp_accountant.compute_rdp(q, z, steps, orders)
|
||||
rdp_from_ledger = rdp_accountant.compute_rdp_from_ledger(ledger, orders)
|
||||
self.assertSequenceAlmostEqual(rdp, rdp_from_ledger)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -156,21 +156,14 @@ def make_gaussian_optimizer_class(cls):
|
|||
l2_norm_clip,
|
||||
noise_multiplier,
|
||||
num_microbatches,
|
||||
ledger,
|
||||
unroll_microbatches=False,
|
||||
*args, # pylint: disable=keyword-arg-before-vararg
|
||||
**kwargs):
|
||||
dp_average_query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip, l2_norm_clip * noise_multiplier, num_microbatches)
|
||||
if 'population_size' in kwargs:
|
||||
population_size = kwargs.pop('population_size')
|
||||
max_queries = kwargs.pop('ledger_max_queries', 1e6)
|
||||
max_samples = kwargs.pop('ledger_max_samples', 1e6)
|
||||
selection_probability = num_microbatches / population_size
|
||||
ledger = privacy_ledger.PrivacyLedger(
|
||||
population_size,
|
||||
selection_probability,
|
||||
max_samples,
|
||||
max_queries)
|
||||
l2_norm_clip, l2_norm_clip * noise_multiplier,
|
||||
num_microbatches, ledger)
|
||||
if ledger:
|
||||
dp_average_query = privacy_ledger.QueryWithLedger(
|
||||
dp_average_query, ledger)
|
||||
|
||||
|
@ -181,6 +174,10 @@ def make_gaussian_optimizer_class(cls):
|
|||
*args,
|
||||
**kwargs)
|
||||
|
||||
@property
|
||||
def ledger(self):
|
||||
return self._ledger
|
||||
|
||||
return DPGaussianOptimizerClass
|
||||
|
||||
# Compatibility with tf 1 and 2 APIs
|
||||
|
|
|
@ -232,7 +232,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
l2_norm_clip=4.0,
|
||||
noise_multiplier=2.0,
|
||||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
learning_rate=2.0,
|
||||
ledger=privacy_ledger.DummyLedger())
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
|
|
|
@ -21,7 +21,8 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from privacy.analysis.rdp_accountant import compute_rdp
|
||||
from privacy.analysis import privacy_ledger
|
||||
from privacy.analysis.rdp_accountant import compute_rdp_from_ledger
|
||||
from privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from privacy.optimizers import dp_optimizer
|
||||
|
||||
|
@ -46,6 +47,27 @@ tf.flags.DEFINE_string('model_dir', None, 'Model directory')
|
|||
FLAGS = tf.flags.FLAGS
|
||||
|
||||
|
||||
class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook):
|
||||
"""Training hook to print current value of epsilon after an epoch."""
|
||||
|
||||
def __init__(self, ledger):
|
||||
"""Initalizes the EpsilonPrintingTrainingHook.
|
||||
|
||||
Args:
|
||||
ledger: The privacy ledger.
|
||||
"""
|
||||
self._samples, self._queries = ledger.get_unformatted_ledger()
|
||||
|
||||
def end(self, session):
|
||||
orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
|
||||
samples = session.run(self._samples)
|
||||
queries = session.run(self._queries)
|
||||
formatted_ledger = privacy_ledger.format_ledger(samples, queries)
|
||||
rdp = compute_rdp_from_ledger(formatted_ledger, orders)
|
||||
eps = get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
|
||||
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
||||
|
||||
|
||||
def cnn_model_fn(features, labels, mode):
|
||||
"""Model function for a CNN."""
|
||||
|
||||
|
@ -75,6 +97,12 @@ def cnn_model_fn(features, labels, mode):
|
|||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
|
||||
if FLAGS.dpsgd:
|
||||
ledger = privacy_ledger.PrivacyLedger(
|
||||
population_size=60000,
|
||||
selection_probability=(FLAGS.batch_size / 60000),
|
||||
max_samples=1e6,
|
||||
max_queries=1e6)
|
||||
|
||||
# Use DP version of GradientDescentOptimizer. Other optimizers are
|
||||
# available in dp_optimizer. Most optimizers inheriting from
|
||||
# tf.train.Optimizer should be wrappable in differentially private
|
||||
|
@ -83,11 +111,15 @@ def cnn_model_fn(features, labels, mode):
|
|||
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
num_microbatches=FLAGS.microbatches,
|
||||
learning_rate=FLAGS.learning_rate,
|
||||
population_size=60000)
|
||||
ledger=ledger,
|
||||
learning_rate=FLAGS.learning_rate)
|
||||
training_hooks = [
|
||||
EpsilonPrintingTrainingHook(ledger)
|
||||
]
|
||||
opt_loss = vector_loss
|
||||
else:
|
||||
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
training_hooks = []
|
||||
opt_loss = scalar_loss
|
||||
global_step = tf.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
|
||||
|
@ -97,7 +129,8 @@ def cnn_model_fn(features, labels, mode):
|
|||
# minimized is opt_loss defined above and passed to optimizer.minimize().
|
||||
return tf.estimator.EstimatorSpec(mode=mode,
|
||||
loss=scalar_loss,
|
||||
train_op=train_op)
|
||||
train_op=train_op,
|
||||
training_hooks=training_hooks)
|
||||
|
||||
# Add evaluation metrics (for EVAL mode).
|
||||
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||
|
@ -107,6 +140,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
labels=labels,
|
||||
predictions=tf.argmax(input=logits, axis=1))
|
||||
}
|
||||
|
||||
return tf.estimator.EstimatorSpec(mode=mode,
|
||||
loss=scalar_loss,
|
||||
eval_metric_ops=eval_metric_ops)
|
||||
|
@ -134,20 +168,6 @@ def load_mnist():
|
|||
return train_data, train_labels, test_data, test_labels
|
||||
|
||||
|
||||
def compute_epsilon(steps):
|
||||
"""Computes epsilon value for given hyperparameters."""
|
||||
if FLAGS.noise_multiplier == 0.0:
|
||||
return float('inf')
|
||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||
sampling_probability = FLAGS.batch_size / 60000
|
||||
rdp = compute_rdp(q=sampling_probability,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
steps=steps,
|
||||
orders=orders)
|
||||
# Delta is set to 1e-5 because MNIST has 60000 training points.
|
||||
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
if FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||
|
@ -184,12 +204,5 @@ def main(unused_argv):
|
|||
test_accuracy = eval_results['accuracy']
|
||||
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
||||
|
||||
# Compute the privacy budget expended so far.
|
||||
if FLAGS.dpsgd:
|
||||
eps = compute_epsilon(epoch * steps_per_epoch)
|
||||
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
||||
else:
|
||||
print('Trained with vanilla non-private SGD optimizer')
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
|
||||
|
|
Loading…
Reference in a new issue