A few new features for QuantileAdaptiveClipSumQuery.
1. Remove redundant global_state.l2_norm_clip from QuantileAdaptiveClipSumQuery. 2. Simplify accumulation code by deriving from SumAggregationDPQuery. 3. Add geometric update option to QuantileAdaptiveClipAverageQuery. PiperOrigin-RevId: 292442733
This commit is contained in:
parent
856eda3aa1
commit
9bb3c1e6d8
2 changed files with 21 additions and 46 deletions
|
@ -34,7 +34,7 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||||
|
|
||||||
|
|
||||||
class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""DPQuery for sum queries with adaptive clipping.
|
"""DPQuery for sum queries with adaptive clipping.
|
||||||
|
|
||||||
Clipping norm is tuned adaptively to converge to a value such that a specified
|
Clipping norm is tuned adaptively to converge to a value such that a specified
|
||||||
|
@ -44,7 +44,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
_GlobalState = collections.namedtuple(
|
_GlobalState = collections.namedtuple(
|
||||||
'_GlobalState', [
|
'_GlobalState', [
|
||||||
'l2_norm_clip',
|
|
||||||
'noise_multiplier',
|
'noise_multiplier',
|
||||||
'target_unclipped_quantile',
|
'target_unclipped_quantile',
|
||||||
'learning_rate',
|
'learning_rate',
|
||||||
|
@ -130,7 +129,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||||
stddev=sum_stddev)
|
stddev=sum_stddev)
|
||||||
|
|
||||||
return self._GlobalState(
|
return self._GlobalState(
|
||||||
initial_l2_norm_clip,
|
|
||||||
noise_multiplier,
|
noise_multiplier,
|
||||||
target_unclipped_quantile,
|
target_unclipped_quantile,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
|
@ -166,33 +164,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||||
# This makes the sensitivity of the clipped count query 0.5 instead of 1.0.
|
# This makes the sensitivity of the clipped count query 0.5 instead of 1.0.
|
||||||
was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5
|
was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5
|
||||||
|
|
||||||
preprocessed_clipped_fraction_record = (
|
return self._SampleState(preprocessed_sum_record, was_clipped)
|
||||||
self._clipped_fraction_query.preprocess_record(
|
|
||||||
params.clipped_fraction_params, was_clipped))
|
|
||||||
|
|
||||||
return preprocessed_sum_record, preprocessed_clipped_fraction_record
|
|
||||||
|
|
||||||
def accumulate_preprocessed_record(
|
|
||||||
self, sample_state, preprocessed_record, weight=1):
|
|
||||||
"""See base class."""
|
|
||||||
preprocessed_sum_record, preprocessed_clipped_fraction_record = preprocessed_record
|
|
||||||
sum_state = self._sum_query.accumulate_preprocessed_record(
|
|
||||||
sample_state.sum_state, preprocessed_sum_record)
|
|
||||||
|
|
||||||
clipped_fraction_state = self._clipped_fraction_query.accumulate_preprocessed_record(
|
|
||||||
sample_state.clipped_fraction_state,
|
|
||||||
preprocessed_clipped_fraction_record)
|
|
||||||
return self._SampleState(sum_state, clipped_fraction_state)
|
|
||||||
|
|
||||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
|
||||||
"""See base class."""
|
|
||||||
return self._SampleState(
|
|
||||||
self._sum_query.merge_sample_states(
|
|
||||||
sample_state_1.sum_state,
|
|
||||||
sample_state_2.sum_state),
|
|
||||||
self._clipped_fraction_query.merge_sample_states(
|
|
||||||
sample_state_1.clipped_fraction_state,
|
|
||||||
sample_state_2.clipped_fraction_state))
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
@ -207,7 +179,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||||
sample_state.clipped_fraction_state,
|
sample_state.clipped_fraction_state,
|
||||||
gs.clipped_fraction_state))
|
gs.clipped_fraction_state))
|
||||||
|
|
||||||
# Unshift clipped percentile by 0.5. (See comment in accumulate_record.)
|
# Unshift clipped percentile by 0.5. (See comment in initializer.)
|
||||||
clipped_quantile = clipped_fraction_result + 0.5
|
clipped_quantile = clipped_fraction_result + 0.5
|
||||||
unclipped_quantile = 1.0 - clipped_quantile
|
unclipped_quantile = 1.0 - clipped_quantile
|
||||||
|
|
||||||
|
@ -221,9 +193,10 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||||
update = global_state.learning_rate * loss_grad
|
update = global_state.learning_rate * loss_grad
|
||||||
|
|
||||||
if self._geometric_update:
|
if self._geometric_update:
|
||||||
new_l2_norm_clip = gs.l2_norm_clip * tf.math.exp(-update)
|
new_l2_norm_clip = gs.sum_state.l2_norm_clip * tf.math.exp(-update)
|
||||||
else:
|
else:
|
||||||
new_l2_norm_clip = tf.math.maximum(0.0, gs.l2_norm_clip - update)
|
new_l2_norm_clip = tf.math.maximum(0.0,
|
||||||
|
gs.sum_state.l2_norm_clip - update)
|
||||||
|
|
||||||
new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
|
new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
|
||||||
new_sum_query_global_state = self._sum_query.make_global_state(
|
new_sum_query_global_state = self._sum_query.make_global_state(
|
||||||
|
@ -231,7 +204,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||||
stddev=new_sum_stddev)
|
stddev=new_sum_stddev)
|
||||||
|
|
||||||
new_global_state = global_state._replace(
|
new_global_state = global_state._replace(
|
||||||
l2_norm_clip=new_l2_norm_clip,
|
|
||||||
sum_state=new_sum_query_global_state,
|
sum_state=new_sum_query_global_state,
|
||||||
clipped_fraction_state=new_clipped_fraction_state)
|
clipped_fraction_state=new_clipped_fraction_state)
|
||||||
|
|
||||||
|
@ -258,7 +230,8 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
|
||||||
target_unclipped_quantile,
|
target_unclipped_quantile,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
clipped_count_stddev,
|
clipped_count_stddev,
|
||||||
expected_num_records):
|
expected_num_records,
|
||||||
|
geometric_update=False):
|
||||||
"""Initializes the AdaptiveClipAverageQuery.
|
"""Initializes the AdaptiveClipAverageQuery.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -277,6 +250,7 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
|
||||||
should be about 0.5 for reasonable privacy.
|
should be about 0.5 for reasonable privacy.
|
||||||
expected_num_records: The expected number of records, used to estimate the
|
expected_num_records: The expected number of records, used to estimate the
|
||||||
clipped count quantile.
|
clipped count quantile.
|
||||||
|
geometric_update: If True, use geometric updating of clip.
|
||||||
"""
|
"""
|
||||||
numerator_query = QuantileAdaptiveClipSumQuery(
|
numerator_query = QuantileAdaptiveClipSumQuery(
|
||||||
initial_l2_norm_clip,
|
initial_l2_norm_clip,
|
||||||
|
@ -284,7 +258,8 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
|
||||||
target_unclipped_quantile,
|
target_unclipped_quantile,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
clipped_count_stddev,
|
clipped_count_stddev,
|
||||||
expected_num_records)
|
expected_num_records,
|
||||||
|
geometric_update)
|
||||||
super(QuantileAdaptiveClipAverageQuery, self).__init__(
|
super(QuantileAdaptiveClipAverageQuery, self).__init__(
|
||||||
numerator_query=numerator_query,
|
numerator_query=numerator_query,
|
||||||
denominator=denominator)
|
denominator=denominator)
|
||||||
|
|
|
@ -142,7 +142,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
initial_clip = global_state.l2_norm_clip
|
initial_clip = global_state.sum_state.l2_norm_clip
|
||||||
self.assertAllClose(initial_clip, 10.0)
|
self.assertAllClose(initial_clip, 10.0)
|
||||||
|
|
||||||
# On the first two iterations, nothing is clipped, so the clip goes down
|
# On the first two iterations, nothing is clipped, so the clip goes down
|
||||||
|
@ -156,7 +156,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
actual_sum, global_state = test_utils.run_query(
|
actual_sum, global_state = test_utils.run_query(
|
||||||
query, [record1, record2], global_state)
|
query, [record1, record2], global_state)
|
||||||
|
|
||||||
actual_clip = global_state.l2_norm_clip
|
actual_clip = global_state.sum_state.l2_norm_clip
|
||||||
|
|
||||||
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
||||||
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
||||||
|
@ -176,7 +176,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
initial_clip = global_state.l2_norm_clip
|
initial_clip = global_state.sum_state.l2_norm_clip
|
||||||
self.assertAllClose(initial_clip, 16.0)
|
self.assertAllClose(initial_clip, 16.0)
|
||||||
|
|
||||||
# For two iterations, nothing is clipped, so the clip is cut in half.
|
# For two iterations, nothing is clipped, so the clip is cut in half.
|
||||||
|
@ -192,7 +192,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
actual_sum, global_state = test_utils.run_query(
|
actual_sum, global_state = test_utils.run_query(
|
||||||
query, [record1, record2], global_state)
|
query, [record1, record2], global_state)
|
||||||
|
|
||||||
actual_clip = global_state.l2_norm_clip
|
actual_clip = global_state.sum_state.l2_norm_clip
|
||||||
|
|
||||||
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
||||||
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
||||||
|
@ -211,7 +211,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
initial_clip = global_state.l2_norm_clip
|
initial_clip = global_state.sum_state.l2_norm_clip
|
||||||
self.assertAllClose(initial_clip, 0.0)
|
self.assertAllClose(initial_clip, 0.0)
|
||||||
|
|
||||||
# On the first two iterations, both are clipped, so the clip goes up
|
# On the first two iterations, both are clipped, so the clip goes up
|
||||||
|
@ -225,7 +225,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
actual_sum, global_state = test_utils.run_query(
|
actual_sum, global_state = test_utils.run_query(
|
||||||
query, [record1, record2], global_state)
|
query, [record1, record2], global_state)
|
||||||
|
|
||||||
actual_clip = global_state.l2_norm_clip
|
actual_clip = global_state.sum_state.l2_norm_clip
|
||||||
|
|
||||||
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
||||||
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
||||||
|
@ -245,7 +245,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
initial_clip = global_state.l2_norm_clip
|
initial_clip = global_state.sum_state.l2_norm_clip
|
||||||
self.assertAllClose(initial_clip, 0.5)
|
self.assertAllClose(initial_clip, 0.5)
|
||||||
|
|
||||||
# On the first two iterations, both are clipped, so the clip is doubled.
|
# On the first two iterations, both are clipped, so the clip is doubled.
|
||||||
|
@ -261,7 +261,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
actual_sum, global_state = test_utils.run_query(
|
actual_sum, global_state = test_utils.run_query(
|
||||||
query, [record1, record2], global_state)
|
query, [record1, record2], global_state)
|
||||||
|
|
||||||
actual_clip = global_state.l2_norm_clip
|
actual_clip = global_state.sum_state.l2_norm_clip
|
||||||
|
|
||||||
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
self.assertAllClose(actual_clip.numpy(), expected_clip)
|
||||||
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
|
||||||
|
@ -295,7 +295,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
|
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
|
||||||
_, global_state = test_utils.run_query(query, records, global_state)
|
_, global_state = test_utils.run_query(query, records, global_state)
|
||||||
|
|
||||||
actual_clip = global_state.l2_norm_clip
|
actual_clip = global_state.sum_state.l2_norm_clip
|
||||||
|
|
||||||
if t > 40:
|
if t > 40:
|
||||||
self.assertNear(actual_clip, 5.0, 0.25)
|
self.assertNear(actual_clip, 5.0, 0.25)
|
||||||
|
@ -325,7 +325,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
|
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
|
||||||
_, global_state = test_utils.run_query(query, records, global_state)
|
_, global_state = test_utils.run_query(query, records, global_state)
|
||||||
|
|
||||||
actual_clip = global_state.l2_norm_clip
|
actual_clip = global_state.sum_state.l2_norm_clip
|
||||||
|
|
||||||
if t > 40:
|
if t > 40:
|
||||||
self.assertNear(actual_clip, 5.0, 0.5)
|
self.assertNear(actual_clip, 5.0, 0.5)
|
||||||
|
|
Loading…
Reference in a new issue