A few improvements to QuantileAdaptiveClipSumQuery and QuantileAdaptiveClipAverageQuery.

1. Remove redundant global_state.l2_norm_clip from QuantileAdaptiveClipSumQuery.
2. Simplify accumulation code by deriving from SumAggregationDPQuery.
3. Add geometric update option to QuantileAdaptiveClipAverageQuery.

PiperOrigin-RevId: 292442733
This commit is contained in:
Galen Andrew 2020-01-30 15:59:02 -08:00 committed by Steve Chien
parent 856eda3aa1
commit 9bb3c1e6d8
2 changed files with 21 additions and 46 deletions

View file

@ -34,7 +34,7 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import normalized_query from tensorflow_privacy.privacy.dp_query import normalized_query
class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
"""DPQuery for sum queries with adaptive clipping. """DPQuery for sum queries with adaptive clipping.
Clipping norm is tuned adaptively to converge to a value such that a specified Clipping norm is tuned adaptively to converge to a value such that a specified
@ -44,7 +44,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
# pylint: disable=invalid-name # pylint: disable=invalid-name
_GlobalState = collections.namedtuple( _GlobalState = collections.namedtuple(
'_GlobalState', [ '_GlobalState', [
'l2_norm_clip',
'noise_multiplier', 'noise_multiplier',
'target_unclipped_quantile', 'target_unclipped_quantile',
'learning_rate', 'learning_rate',
@ -130,7 +129,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
stddev=sum_stddev) stddev=sum_stddev)
return self._GlobalState( return self._GlobalState(
initial_l2_norm_clip,
noise_multiplier, noise_multiplier,
target_unclipped_quantile, target_unclipped_quantile,
learning_rate, learning_rate,
@ -166,33 +164,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
# This makes the sensitivity of the clipped count query 0.5 instead of 1.0. # This makes the sensitivity of the clipped count query 0.5 instead of 1.0.
was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5 was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5
preprocessed_clipped_fraction_record = ( return self._SampleState(preprocessed_sum_record, was_clipped)
self._clipped_fraction_query.preprocess_record(
params.clipped_fraction_params, was_clipped))
return preprocessed_sum_record, preprocessed_clipped_fraction_record
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record, weight=1):
"""See base class."""
preprocessed_sum_record, preprocessed_clipped_fraction_record = preprocessed_record
sum_state = self._sum_query.accumulate_preprocessed_record(
sample_state.sum_state, preprocessed_sum_record)
clipped_fraction_state = self._clipped_fraction_query.accumulate_preprocessed_record(
sample_state.clipped_fraction_state,
preprocessed_clipped_fraction_record)
return self._SampleState(sum_state, clipped_fraction_state)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""See base class."""
return self._SampleState(
self._sum_query.merge_sample_states(
sample_state_1.sum_state,
sample_state_2.sum_state),
self._clipped_fraction_query.merge_sample_states(
sample_state_1.clipped_fraction_state,
sample_state_2.clipped_fraction_state))
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""See base class.""" """See base class."""
@ -207,7 +179,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
sample_state.clipped_fraction_state, sample_state.clipped_fraction_state,
gs.clipped_fraction_state)) gs.clipped_fraction_state))
# Unshift clipped percentile by 0.5. (See comment in accumulate_record.) # Unshift clipped percentile by 0.5. (See comment in initializer.)
clipped_quantile = clipped_fraction_result + 0.5 clipped_quantile = clipped_fraction_result + 0.5
unclipped_quantile = 1.0 - clipped_quantile unclipped_quantile = 1.0 - clipped_quantile
@ -221,9 +193,10 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
update = global_state.learning_rate * loss_grad update = global_state.learning_rate * loss_grad
if self._geometric_update: if self._geometric_update:
new_l2_norm_clip = gs.l2_norm_clip * tf.math.exp(-update) new_l2_norm_clip = gs.sum_state.l2_norm_clip * tf.math.exp(-update)
else: else:
new_l2_norm_clip = tf.math.maximum(0.0, gs.l2_norm_clip - update) new_l2_norm_clip = tf.math.maximum(0.0,
gs.sum_state.l2_norm_clip - update)
new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
new_sum_query_global_state = self._sum_query.make_global_state( new_sum_query_global_state = self._sum_query.make_global_state(
@ -231,7 +204,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
stddev=new_sum_stddev) stddev=new_sum_stddev)
new_global_state = global_state._replace( new_global_state = global_state._replace(
l2_norm_clip=new_l2_norm_clip,
sum_state=new_sum_query_global_state, sum_state=new_sum_query_global_state,
clipped_fraction_state=new_clipped_fraction_state) clipped_fraction_state=new_clipped_fraction_state)
@ -258,7 +230,8 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
target_unclipped_quantile, target_unclipped_quantile,
learning_rate, learning_rate,
clipped_count_stddev, clipped_count_stddev,
expected_num_records): expected_num_records,
geometric_update=False):
"""Initializes the AdaptiveClipAverageQuery. """Initializes the AdaptiveClipAverageQuery.
Args: Args:
@ -277,6 +250,7 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
should be about 0.5 for reasonable privacy. should be about 0.5 for reasonable privacy.
expected_num_records: The expected number of records, used to estimate the expected_num_records: The expected number of records, used to estimate the
clipped count quantile. clipped count quantile.
geometric_update: If True, update the clipping norm geometrically (multiplicatively, by a factor of exp(-update)) instead of additively.
""" """
numerator_query = QuantileAdaptiveClipSumQuery( numerator_query = QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip, initial_l2_norm_clip,
@ -284,7 +258,8 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
target_unclipped_quantile, target_unclipped_quantile,
learning_rate, learning_rate,
clipped_count_stddev, clipped_count_stddev,
expected_num_records) expected_num_records,
geometric_update)
super(QuantileAdaptiveClipAverageQuery, self).__init__( super(QuantileAdaptiveClipAverageQuery, self).__init__(
numerator_query=numerator_query, numerator_query=numerator_query,
denominator=denominator) denominator=denominator)

View file

@ -142,7 +142,7 @@ class QuantileAdaptiveClipSumQueryTest(
global_state = query.initial_global_state() global_state = query.initial_global_state()
initial_clip = global_state.l2_norm_clip initial_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(initial_clip, 10.0) self.assertAllClose(initial_clip, 10.0)
# On the first two iterations, nothing is clipped, so the clip goes down # On the first two iterations, nothing is clipped, so the clip goes down
@ -156,7 +156,7 @@ class QuantileAdaptiveClipSumQueryTest(
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(
query, [record1, record2], global_state) query, [record1, record2], global_state)
actual_clip = global_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_clip.numpy(), expected_clip)
self.assertAllClose(actual_sum.numpy(), (expected_sum,)) self.assertAllClose(actual_sum.numpy(), (expected_sum,))
@ -176,7 +176,7 @@ class QuantileAdaptiveClipSumQueryTest(
global_state = query.initial_global_state() global_state = query.initial_global_state()
initial_clip = global_state.l2_norm_clip initial_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(initial_clip, 16.0) self.assertAllClose(initial_clip, 16.0)
# For two iterations, nothing is clipped, so the clip is cut in half. # For two iterations, nothing is clipped, so the clip is cut in half.
@ -192,7 +192,7 @@ class QuantileAdaptiveClipSumQueryTest(
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(
query, [record1, record2], global_state) query, [record1, record2], global_state)
actual_clip = global_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_clip.numpy(), expected_clip)
self.assertAllClose(actual_sum.numpy(), (expected_sum,)) self.assertAllClose(actual_sum.numpy(), (expected_sum,))
@ -211,7 +211,7 @@ class QuantileAdaptiveClipSumQueryTest(
global_state = query.initial_global_state() global_state = query.initial_global_state()
initial_clip = global_state.l2_norm_clip initial_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(initial_clip, 0.0) self.assertAllClose(initial_clip, 0.0)
# On the first two iterations, both are clipped, so the clip goes up # On the first two iterations, both are clipped, so the clip goes up
@ -225,7 +225,7 @@ class QuantileAdaptiveClipSumQueryTest(
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(
query, [record1, record2], global_state) query, [record1, record2], global_state)
actual_clip = global_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_clip.numpy(), expected_clip)
self.assertAllClose(actual_sum.numpy(), (expected_sum,)) self.assertAllClose(actual_sum.numpy(), (expected_sum,))
@ -245,7 +245,7 @@ class QuantileAdaptiveClipSumQueryTest(
global_state = query.initial_global_state() global_state = query.initial_global_state()
initial_clip = global_state.l2_norm_clip initial_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(initial_clip, 0.5) self.assertAllClose(initial_clip, 0.5)
# On the first two iterations, both are clipped, so the clip is doubled. # On the first two iterations, both are clipped, so the clip is doubled.
@ -261,7 +261,7 @@ class QuantileAdaptiveClipSumQueryTest(
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(
query, [record1, record2], global_state) query, [record1, record2], global_state)
actual_clip = global_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_clip.numpy(), expected_clip)
self.assertAllClose(actual_sum.numpy(), (expected_sum,)) self.assertAllClose(actual_sum.numpy(), (expected_sum,))
@ -295,7 +295,7 @@ class QuantileAdaptiveClipSumQueryTest(
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1)) tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
_, global_state = test_utils.run_query(query, records, global_state) _, global_state = test_utils.run_query(query, records, global_state)
actual_clip = global_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
if t > 40: if t > 40:
self.assertNear(actual_clip, 5.0, 0.25) self.assertNear(actual_clip, 5.0, 0.25)
@ -325,7 +325,7 @@ class QuantileAdaptiveClipSumQueryTest(
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1)) tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
_, global_state = test_utils.run_query(query, records, global_state) _, global_state = test_utils.run_query(query, records, global_state)
actual_clip = global_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
if t > 40: if t > 40:
self.assertNear(actual_clip, 5.0, 0.5) self.assertNear(actual_clip, 5.0, 0.5)