Adds AdaptiveClipAverageQuery which performs adaptive adjustment of the clipping norm to approximate a specified quantile of clipped updates per round.
PiperOrigin-RevId: 238698171
This commit is contained in:
parent
947e17dcce
commit
9a53e1eb86
5 changed files with 47 additions and 31 deletions
|
@ -79,6 +79,23 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
with tf.control_dependencies(dependencies):
|
||||
return nest.map_structure(tf.zeros_like, tensors)
|
||||
|
||||
def accumulate_record_impl(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
sample_state: The current sample state.
|
||||
record: The record to accumulate.
|
||||
|
||||
Returns:
|
||||
A tuple containing the updated sample state and the global norm.
|
||||
"""
|
||||
l2_norm_clip = params
|
||||
record_as_list = nest.flatten(record)
|
||||
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
||||
clipped = nest.pack_sequence_as(record, clipped_as_list)
|
||||
return nest.map_structure(tf.add, sample_state, clipped), norm
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
|
||||
|
@ -90,11 +107,9 @@ class GaussianSumQuery(dp_query.DPQuery):
|
|||
Returns:
|
||||
The updated sample state.
|
||||
"""
|
||||
l2_norm_clip = params
|
||||
record_as_list = nest.flatten(record)
|
||||
clipped_as_list, _ = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
||||
clipped = nest.pack_sequence_as(record, clipped_as_list)
|
||||
return nest.map_structure(tf.add, sample_state, clipped)
|
||||
new_sample_state, _ = self.accumulate_record_impl(
|
||||
params, sample_state, record)
|
||||
return new_sample_state
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Gets noised sum after all records of sample have been accumulated.
|
||||
|
|
|
@ -36,7 +36,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -48,7 +48,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=5.0, stddev=0.0)
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -63,7 +63,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder)
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=l2_norm_clip, stddev=0.0)
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
result = sess.run(query_result)
|
||||
|
@ -82,7 +82,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=5.0, stddev=stddev)
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
noised_sums = []
|
||||
for _ in xrange(1000):
|
||||
|
@ -98,7 +98,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected_average = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected_average)
|
||||
|
@ -111,7 +111,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
noised_averages = []
|
||||
for _ in range(1000):
|
||||
|
|
|
@ -46,7 +46,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = [1.0, [2.0, 3.0]]
|
||||
record2 = [4.0, [3.0, 2.0]]
|
||||
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [5.0, [5.0, 5.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -63,7 +63,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = [1.0, [2.0, 3.0]]
|
||||
record2 = [4.0, [3.0, 2.0]]
|
||||
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, [1.0, 1.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -80,7 +80,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
|
||||
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
|
||||
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, [1.0, 1.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -100,7 +100,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
||||
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
|
||||
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -119,7 +119,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = (3.0, [2.0, 1.5])
|
||||
record2 = (0.0, [-1.0, -3.5])
|
||||
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
noised_averages = []
|
||||
for _ in range(1000):
|
||||
|
|
|
@ -33,7 +33,7 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record2 = tf.constant([-1.0, 1.0])
|
||||
|
||||
query = no_privacy_query.NoPrivacySumQuery()
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -43,12 +43,11 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = tf.constant([2.0, 0.0])
|
||||
record2 = tf.constant([-1.0, 1.0])
|
||||
|
||||
weight1 = 1
|
||||
weight2 = 2
|
||||
weights = [1, 2]
|
||||
|
||||
query = no_privacy_query.NoPrivacySumQuery()
|
||||
query_result = test_utils.run_query(
|
||||
query, [record1, record2], [weight1, weight2])
|
||||
query_result, _ = test_utils.run_query(
|
||||
query, [record1, record2], weights=weights)
|
||||
result = sess.run(query_result)
|
||||
expected = [0.0, 2.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -59,7 +58,7 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record2 = tf.constant([-1.0, 2.0])
|
||||
|
||||
query = no_privacy_query.NoPrivacyAverageQuery()
|
||||
query_result = test_utils.run_query(query, [record1, record2])
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [2.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -69,12 +68,11 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = tf.constant([4.0, 0.0])
|
||||
record2 = tf.constant([-1.0, 1.0])
|
||||
|
||||
weight1 = 1
|
||||
weight2 = 3
|
||||
weights = [1, 3]
|
||||
|
||||
query = no_privacy_query.NoPrivacyAverageQuery()
|
||||
query_result = test_utils.run_query(
|
||||
query, [record1, record2], [weight1, weight2])
|
||||
query_result, _ = test_utils.run_query(
|
||||
query, [record1, record2], weights=weights)
|
||||
result = sess.run(query_result)
|
||||
expected = [0.25, 0.75]
|
||||
self.assertAllClose(result, expected)
|
||||
|
|
|
@ -21,18 +21,22 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
|
||||
def run_query(query, records, weights=None):
|
||||
def run_query(query, records, global_state=None, weights=None):
|
||||
"""Executes query on the given set of records as a single sample.
|
||||
|
||||
Args:
|
||||
query: A PrivateQuery to run.
|
||||
records: An iterable containing records to pass to the query.
|
||||
global_state: The current global state. If None, an initial global state is
|
||||
generated.
|
||||
weights: An optional iterable containing the weights of the records.
|
||||
|
||||
Returns:
|
||||
The result of the query.
|
||||
A tuple (result, new_global_state) where "result" is the result of the
|
||||
query and "new_global_state" is the updated global state.
|
||||
"""
|
||||
global_state = query.initial_global_state()
|
||||
if not global_state:
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
sample_state = query.initial_sample_state(global_state, next(iter(records)))
|
||||
if weights is None:
|
||||
|
@ -42,5 +46,4 @@ def run_query(query, records, weights=None):
|
|||
for weight, record in zip(weights, records):
|
||||
sample_state = query.accumulate_record(
|
||||
params, sample_state, record, weight)
|
||||
result, _ = query.get_noised_result(sample_state, global_state)
|
||||
return result
|
||||
return query.get_noised_result(sample_state, global_state)
|
||||
|
|
Loading…
Reference in a new issue