Update PrivacyLedger and DPOptimizer to make certain arguments optional.

PiperOrigin-RevId: 246235646
Steve Chien 2019-05-01 18:07:11 -07:00 committed by A. Unique TensorFlower
parent c09ec4c22b
commit beb86c6e18
5 changed files with 67 additions and 35 deletions
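In practice this means the ledger's selection probability and buffer sizes, as well as the optimizer's num_microbatches, can now be left unset at construction and filled in once the batch is known. Below is a minimal sketch of the new calling pattern, pieced together from the diffs that follow; the import paths in the comments are assumptions and the numbers are purely illustrative:

    # Assumed imports (exact module paths may differ in your checkout):
    # from privacy.analysis import privacy_ledger
    # from privacy.dp_query import gaussian_query
    # from privacy.optimizers import dp_optimizer

    population_size = 60000  # illustrative training-set size

    # selection_probability, max_samples and max_queries may now be omitted;
    # the buffer sizes default to 1000 * population_size and the probability
    # is filled in later through PrivacyLedger.set_sample_size().
    ledger = privacy_ledger.PrivacyLedger(population_size)

    # The denominator may now be None; it is set once the batch size is known.
    # Positional args: l2_norm_clip, sum_stddev, denominator, ledger.
    dp_average_query = gaussian_query.GaussianAverageQuery(1.0, 1.0, None, ledger)
    dp_average_query = privacy_ledger.QueryWithLedger(dp_average_query, ledger)

    # num_microbatches defaults to None: each example becomes its own
    # microbatch, and the optimizer derives the count and calls
    # set_denominator() inside compute_gradients().
    opt = dp_optimizer.DPAdamOptimizer(dp_average_query)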

View file

@@ -11,9 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PrivacyLedger class for keeping a record of private queries.
-"""
+"""PrivacyLedger class for keeping a record of private queries."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -65,36 +63,39 @@ class PrivacyLedger(object):
   for the purpose of computing privacy guarantees.
   """
 
-  def __init__(
-      self,
-      population_size,
-      selection_probability,
-      max_samples,
-      max_queries):
+  def __init__(self,
+               population_size,
+               selection_probability=None,
+               max_samples=None,
+               max_queries=None):
     """Initialize the PrivacyLedger.
 
     Args:
       population_size: An integer (may be variable) specifying the size of the
-        population.
+        population, i.e. size of the training data used in each epoch.
      selection_probability: A float (may be variable) specifying the
        probability each record is included in a sample.
-      max_samples: The maximum number of samples. An exception is thrown if
-        more than this many samples are recorded.
-      max_queries: The maximum number of queries. An exception is thrown if
-        more than this many queries are recorded.
+      max_samples: The maximum number of samples. An exception is thrown if more
+        than this many samples are recorded.
+      max_queries: The maximum number of queries. An exception is thrown if more
+        than this many queries are recorded.
     """
     self._population_size = population_size
     self._selection_probability = selection_probability
+    if max_samples is None:
+      max_samples = 1000 * population_size
+    if max_queries is None:
+      max_queries = 1000 * population_size
     # The query buffer stores rows corresponding to GaussianSumQueryEntries.
-    self._query_buffer = tensor_buffer.TensorBuffer(
-        max_queries, [3], tf.float32, 'query')
+    self._query_buffer = tensor_buffer.TensorBuffer(max_queries, [3],
+                                                    tf.float32, 'query')
     self._sample_var = tf.Variable(
         initial_value=tf.zeros([3]), trainable=False, name='sample')
     # The sample buffer stores rows corresponding to SampleEntries.
-    self._sample_buffer = tensor_buffer.TensorBuffer(
-        max_samples, [3], tf.float32, 'sample')
+    self._sample_buffer = tensor_buffer.TensorBuffer(max_samples, [3],
+                                                     tf.float32, 'sample')
     self._sample_count = tf.Variable(
         initial_value=0.0, trainable=False, name='sample_count')
     self._query_count = tf.Variable(
@@ -116,9 +117,10 @@ class PrivacyLedger(object):
     Returns:
       An operation recording the sum query to the ledger.
     """
     def _do_record_query():
-      with tf.control_dependencies([
-          tf.assign(self._query_count, self._query_count + 1)]):
+      with tf.control_dependencies(
+          [tf.assign(self._query_count, self._query_count + 1)]):
         return self._query_buffer.append(
             [self._sample_count, l2_norm_bound, noise_stddev])
@@ -127,14 +129,15 @@ class PrivacyLedger(object):
   def finalize_sample(self):
     """Finalizes sample and records sample ledger entry."""
     with tf.control_dependencies([
-        tf.assign(
-            self._sample_var,
-            [self._population_size,
-             self._selection_probability,
-             self._query_count])]):
+        tf.assign(self._sample_var, [
+            self._population_size, self._selection_probability,
+            self._query_count
+        ])
+    ]):
       with tf.control_dependencies([
           tf.assign(self._sample_count, self._sample_count + 1),
-          tf.assign(self._query_count, 0)]):
+          tf.assign(self._query_count, 0)
+      ]):
         return self._sample_buffer.append(self._sample_var)
 
   def get_unformatted_ledger(self):
@@ -165,6 +168,10 @@ class PrivacyLedger(object):
     return format_ledger(sample_array, query_array)
 
+  def set_sample_size(self, batch_size):
+    self._selection_probability = tf.cast(batch_size,
+                                          tf.float32) / self._population_size
+
 
 class DummyLedger(object):
   """A ledger that records nothing.
@@ -212,8 +219,8 @@ class QueryWithLedger(dp_query.DPQuery):
     Args:
       query: The query whose events should be recorded to the ledger. Any
-        subqueries (including those in the leaves of a nested query) should
-        also contain a reference to the same ledger given here.
+        subqueries (including those in the leaves of a nested query) should also
+        contain a reference to the same ledger given here.
       ledger: A PrivacyLedger to which privacy events should be recorded.
     """
     self._query = query
@@ -240,3 +247,7 @@ class QueryWithLedger(dp_query.DPQuery):
     with tf.control_dependencies(nest.flatten(sample_state)):
       with tf.control_dependencies([self._ledger.finalize_sample()]):
         return self._query.get_noised_result(sample_state, global_state)
+
+  def set_denominator(self, num_microbatches, microbatch_size=1):
+    self._query.set_denominator(num_microbatches)
+    self._ledger.set_sample_size(num_microbatches * microbatch_size)
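Taken together, the new QueryWithLedger.set_denominator forwards the microbatch count to the underlying query and, via PrivacyLedger.set_sample_size, records the sampling probability it implies. As a worked example with purely illustrative numbers: with population_size = 60000 and a minibatch of 256 examples treated as 256 single-example microbatches, set_denominator(256) makes the normalizer 256.0 and sets the ledger's selection probability to 256 / 60000, roughly 0.0043.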

View file

@@ -167,4 +167,4 @@ class GaussianAverageQuery(normalized_query.NormalizedQuery):
     """
     super(GaussianAverageQuery, self).__init__(
         numerator_query=GaussianSumQuery(l2_norm_clip, sum_stddev, ledger),
-        denominator=tf.cast(denominator, tf.float32))
+        denominator=denominator)

View file

@@ -41,7 +41,8 @@ class NormalizedQuery(dp_query.DPQuery):
       denominator: A value for the denominator.
     """
     self._numerator = numerator_query
-    self._denominator = tf.cast(denominator, tf.float32)
+    self._denominator = tf.cast(denominator,
+                                tf.float32) if denominator is not None else None
 
   def initial_global_state(self):
     """Returns the initial global state for the NormalizedQuery."""
@@ -103,4 +104,5 @@ class NormalizedQuery(dp_query.DPQuery):
     return nest.map_structure(normalize, noised_sum), new_sum_global_state
+
+  def set_denominator(self, denominator):
+    self._denominator = tf.cast(denominator, tf.float32)

View file

@@ -47,10 +47,22 @@ def make_optimizer_class(cls):
     def __init__(
         self,
         dp_average_query,
-        num_microbatches,
+        num_microbatches=None,
         unroll_microbatches=False,
-        *args,  # pylint: disable=keyword-arg-before-vararg
+        *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
+      """Initialize the DPOptimizerClass.
+
+      Args:
+        dp_average_query: DPQuery object, specifying differential privacy
+          mechanism to use.
+        num_microbatches: How many microbatches into which the minibatch is
+          split. If None, will default to the size of the minibatch, and
+          per-example gradients will be computed.
+        unroll_microbatches: If true, processes microbatches within a Python
+          loop instead of a tf.while_loop. Can be used if using a tf.while_loop
+          raises an exception.
+      """
       super(DPOptimizerClass, self).__init__(*args, **kwargs)
       self._dp_average_query = dp_average_query
       self._num_microbatches = num_microbatches
@@ -74,6 +86,9 @@ def make_optimizer_class(cls):
           raise ValueError('When in Eager mode, a tape needs to be passed.')
 
         vector_loss = loss()
+        if self._num_microbatches is None:
+          self._num_microbatches = tf.shape(vector_loss)[0]
+        self._dp_average_query.set_denominator(self._num_microbatches)
         sample_state = self._dp_average_query.initial_sample_state(
             self._global_state, var_list)
         microbatches_losses = tf.reshape(vector_loss,
@@ -109,6 +124,9 @@ def make_optimizer_class(cls):
        # we sampled each microbatch from the appropriate binomial distribution,
        # although that still wouldn't be quite correct because it would be
        # sampling from the dataset without replacement.
+        if self._num_microbatches is None:
+          self._num_microbatches = tf.shape(loss)[0]
+        self._dp_average_query.set_denominator(self._num_microbatches)
         microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
         sample_params = (
             self._dp_average_query.derive_sample_params(self._global_state))
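The effect of the None default is easiest to see in the reshape above: once num_microbatches falls back to the batch size, every example becomes its own microbatch. A tiny standalone illustration of that arithmetic (plain NumPy, not library code):

    import numpy as np

    vector_loss = np.arange(8.0)              # stand-in for 8 per-example losses
    num_microbatches = vector_loss.shape[0]   # what tf.shape(vector_loss)[0] yields
    microbatches_losses = vector_loss.reshape(num_microbatches, -1)
    assert microbatches_losses.shape == (8, 1)  # one loss row per microbatch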

View file

@@ -46,14 +46,15 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       ('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]),
       ('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-2.5, -2.5]),
       ('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]),
-      ('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]))
+      ('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]),
+      ('DPAdam None', dp_optimizer.DPAdamOptimizer, None, [-2.5, -2.5]))
   def testBaseline(self, cls, num_microbatches, expected_answer):
     with self.cached_session() as sess:
       var0 = tf.Variable([1.0, 2.0])
       data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
 
       ledger = privacy_ledger.PrivacyLedger(
-          1e6, num_microbatches / 1e6, 50, 50)
+          1e6, num_microbatches / 1e6 if num_microbatches else None, 50, 50)
       dp_average_query = gaussian_query.GaussianAverageQuery(
           1.0e9, 0.0, num_microbatches, ledger)
       dp_average_query = privacy_ledger.QueryWithLedger(