Adds AdaptiveClipAverageQuery which performs adaptive adjustment of the clipping norm to approximate a specified quantile of clipped updates per round.

PiperOrigin-RevId: 238698171
This commit is contained in:
Galen Andrew 2019-03-15 13:18:58 -07:00 committed by A. Unique TensorFlower
parent 947e17dcce
commit 9a53e1eb86
5 changed files with 47 additions and 31 deletions

View file

@ -79,6 +79,23 @@ class GaussianSumQuery(dp_query.DPQuery):
with tf.control_dependencies(dependencies): with tf.control_dependencies(dependencies):
return nest.map_structure(tf.zeros_like, tensors) return nest.map_structure(tf.zeros_like, tensors)
def accumulate_record_impl(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
Args:
params: The parameters for the sample.
sample_state: The current sample state.
record: The record to accumulate.
Returns:
A tuple containing the updated sample state and the global norm.
"""
l2_norm_clip = params
record_as_list = nest.flatten(record)
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
clipped = nest.pack_sequence_as(record, clipped_as_list)
return nest.map_structure(tf.add, sample_state, clipped), norm
def accumulate_record(self, params, sample_state, record): def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state. """Accumulates a single record into the sample state.
@ -90,11 +107,9 @@ class GaussianSumQuery(dp_query.DPQuery):
Returns: Returns:
The updated sample state. The updated sample state.
""" """
l2_norm_clip = params new_sample_state, _ = self.accumulate_record_impl(
record_as_list = nest.flatten(record) params, sample_state, record)
clipped_as_list, _ = tf.clip_by_global_norm(record_as_list, l2_norm_clip) return new_sample_state
clipped = nest.pack_sequence_as(record, clipped_as_list)
return nest.map_structure(tf.add, sample_state, clipped)
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Gets noised sum after all records of sample have been accumulated. """Gets noised sum after all records of sample have been accumulated.

View file

@ -36,7 +36,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0) l2_norm_clip=10.0, stddev=0.0)
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [1.0, 1.0] expected = [1.0, 1.0]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -48,7 +48,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(
l2_norm_clip=5.0, stddev=0.0) l2_norm_clip=5.0, stddev=0.0)
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [1.0, 1.0] expected = [1.0, 1.0]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -63,7 +63,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder) assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder)
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(
l2_norm_clip=l2_norm_clip, stddev=0.0) l2_norm_clip=l2_norm_clip, stddev=0.0)
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
self.evaluate(tf.global_variables_initializer()) self.evaluate(tf.global_variables_initializer())
result = sess.run(query_result) result = sess.run(query_result)
@ -82,7 +82,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(
l2_norm_clip=5.0, stddev=stddev) l2_norm_clip=5.0, stddev=stddev)
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
noised_sums = [] noised_sums = []
for _ in xrange(1000): for _ in xrange(1000):
@ -98,7 +98,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
query = gaussian_query.GaussianAverageQuery( query = gaussian_query.GaussianAverageQuery(
l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0) l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected_average = [1.0, 1.0] expected_average = [1.0, 1.0]
self.assertAllClose(result, expected_average) self.assertAllClose(result, expected_average)
@ -111,7 +111,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
query = gaussian_query.GaussianAverageQuery( query = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator) l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
noised_averages = [] noised_averages = []
for _ in range(1000): for _ in range(1000):

View file

@ -46,7 +46,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = [1.0, [2.0, 3.0]] record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]] record2 = [4.0, [3.0, 2.0]]
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [5.0, [5.0, 5.0]] expected = [5.0, [5.0, 5.0]]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -63,7 +63,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = [1.0, [2.0, 3.0]] record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]] record2 = [4.0, [3.0, 2.0]]
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [1.0, [1.0, 1.0]] expected = [1.0, [1.0, 1.0]]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -80,7 +80,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]] record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]] record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [1.0, [1.0, 1.0]] expected = [1.0, [1.0, 1.0]]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -100,7 +100,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}] record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}] record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}] expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -119,7 +119,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = (3.0, [2.0, 1.5]) record1 = (3.0, [2.0, 1.5])
record2 = (0.0, [-1.0, -3.5]) record2 = (0.0, [-1.0, -3.5])
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
noised_averages = [] noised_averages = []
for _ in range(1000): for _ in range(1000):

View file

@ -33,7 +33,7 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
record2 = tf.constant([-1.0, 1.0]) record2 = tf.constant([-1.0, 1.0])
query = no_privacy_query.NoPrivacySumQuery() query = no_privacy_query.NoPrivacySumQuery()
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [1.0, 1.0] expected = [1.0, 1.0]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -43,12 +43,11 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = tf.constant([2.0, 0.0]) record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0]) record2 = tf.constant([-1.0, 1.0])
weight1 = 1 weights = [1, 2]
weight2 = 2
query = no_privacy_query.NoPrivacySumQuery() query = no_privacy_query.NoPrivacySumQuery()
query_result = test_utils.run_query( query_result, _ = test_utils.run_query(
query, [record1, record2], [weight1, weight2]) query, [record1, record2], weights=weights)
result = sess.run(query_result) result = sess.run(query_result)
expected = [0.0, 2.0] expected = [0.0, 2.0]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -59,7 +58,7 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
record2 = tf.constant([-1.0, 2.0]) record2 = tf.constant([-1.0, 2.0])
query = no_privacy_query.NoPrivacyAverageQuery() query = no_privacy_query.NoPrivacyAverageQuery()
query_result = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [2.0, 1.0] expected = [2.0, 1.0]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
@ -69,12 +68,11 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = tf.constant([4.0, 0.0]) record1 = tf.constant([4.0, 0.0])
record2 = tf.constant([-1.0, 1.0]) record2 = tf.constant([-1.0, 1.0])
weight1 = 1 weights = [1, 3]
weight2 = 3
query = no_privacy_query.NoPrivacyAverageQuery() query = no_privacy_query.NoPrivacyAverageQuery()
query_result = test_utils.run_query( query_result, _ = test_utils.run_query(
query, [record1, record2], [weight1, weight2]) query, [record1, record2], weights=weights)
result = sess.run(query_result) result = sess.run(query_result)
expected = [0.25, 0.75] expected = [0.25, 0.75]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)

View file

@ -21,18 +21,22 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
def run_query(query, records, weights=None): def run_query(query, records, global_state=None, weights=None):
"""Executes query on the given set of records as a single sample. """Executes query on the given set of records as a single sample.
Args: Args:
query: A PrivateQuery to run. query: A PrivateQuery to run.
records: An iterable containing records to pass to the query. records: An iterable containing records to pass to the query.
global_state: The current global state. If None, an initial global state is
generated.
weights: An optional iterable containing the weights of the records. weights: An optional iterable containing the weights of the records.
Returns: Returns:
The result of the query. A tuple (result, new_global_state) where "result" is the result of the
query and "new_global_state" is the updated global state.
""" """
global_state = query.initial_global_state() if not global_state:
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(global_state, next(iter(records))) sample_state = query.initial_sample_state(global_state, next(iter(records)))
if weights is None: if weights is None:
@ -42,5 +46,4 @@ def run_query(query, records, weights=None):
for weight, record in zip(weights, records): for weight, record in zip(weights, records):
sample_state = query.accumulate_record( sample_state = query.accumulate_record(
params, sample_state, record, weight) params, sample_state, record, weight)
result, _ = query.get_noised_result(sample_state, global_state) return query.get_noised_result(sample_state, global_state)
return result