diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py index 31edb1f..25ec88b 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py @@ -34,7 +34,7 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import normalized_query -class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): +class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): """DPQuery for sum queries with adaptive clipping. Clipping norm is tuned adaptively to converge to a value such that a specified @@ -44,7 +44,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): # pylint: disable=invalid-name _GlobalState = collections.namedtuple( '_GlobalState', [ - 'l2_norm_clip', 'noise_multiplier', 'target_unclipped_quantile', 'learning_rate', @@ -130,7 +129,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): stddev=sum_stddev) return self._GlobalState( - initial_l2_norm_clip, noise_multiplier, target_unclipped_quantile, learning_rate, @@ -166,33 +164,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): # This makes the sensitivity of the clipped count query 0.5 instead of 1.0. was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5 - preprocessed_clipped_fraction_record = ( - self._clipped_fraction_query.preprocess_record( - params.clipped_fraction_params, was_clipped)) - - return preprocessed_sum_record, preprocessed_clipped_fraction_record - - def accumulate_preprocessed_record( - self, sample_state, preprocessed_record, weight=1): - """See base class.""" - preprocessed_sum_record, preprocessed_clipped_fraction_record = preprocessed_record - sum_state = self._sum_query.accumulate_preprocessed_record( - sample_state.sum_state, preprocessed_sum_record) - - clipped_fraction_state = self._clipped_fraction_query.accumulate_preprocessed_record( - sample_state.clipped_fraction_state, - preprocessed_clipped_fraction_record) - return self._SampleState(sum_state, clipped_fraction_state) - - def merge_sample_states(self, sample_state_1, sample_state_2): - """See base class.""" - return self._SampleState( - self._sum_query.merge_sample_states( - sample_state_1.sum_state, - sample_state_2.sum_state), - self._clipped_fraction_query.merge_sample_states( - sample_state_1.clipped_fraction_state, - sample_state_2.clipped_fraction_state)) + return self._SampleState(preprocessed_sum_record, was_clipped) def get_noised_result(self, sample_state, global_state): """See base class.""" @@ -207,7 +179,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): sample_state.clipped_fraction_state, gs.clipped_fraction_state)) - # Unshift clipped percentile by 0.5. (See comment in accumulate_record.) + # Unshift clipped percentile by 0.5. (See comment in initializer.) clipped_quantile = clipped_fraction_result + 0.5 unclipped_quantile = 1.0 - clipped_quantile @@ -221,9 +193,10 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): update = global_state.learning_rate * loss_grad if self._geometric_update: - new_l2_norm_clip = gs.l2_norm_clip * tf.math.exp(-update) + new_l2_norm_clip = gs.sum_state.l2_norm_clip * tf.math.exp(-update) else: - new_l2_norm_clip = tf.math.maximum(0.0, gs.l2_norm_clip - update) + new_l2_norm_clip = tf.math.maximum(0.0, + gs.sum_state.l2_norm_clip - update) new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier new_sum_query_global_state = self._sum_query.make_global_state( @@ -231,7 +204,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): stddev=new_sum_stddev) new_global_state = global_state._replace( - l2_norm_clip=new_l2_norm_clip, sum_state=new_sum_query_global_state, clipped_fraction_state=new_clipped_fraction_state) @@ -258,7 +230,8 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery): target_unclipped_quantile, learning_rate, clipped_count_stddev, - expected_num_records): + expected_num_records, + geometric_update=False): """Initializes the AdaptiveClipAverageQuery. Args: @@ -277,6 +250,7 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery): should be about 0.5 for reasonable privacy. expected_num_records: The expected number of records, used to estimate the clipped count quantile. + geometric_update: If True, use geometric updating of clip. """ numerator_query = QuantileAdaptiveClipSumQuery( initial_l2_norm_clip, @@ -284,7 +258,8 @@ class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery): target_unclipped_quantile, learning_rate, clipped_count_stddev, - expected_num_records) + expected_num_records, + geometric_update) super(QuantileAdaptiveClipAverageQuery, self).__init__( numerator_query=numerator_query, denominator=denominator) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py index 67d4041..dbe766f 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py @@ -142,7 +142,7 @@ class QuantileAdaptiveClipSumQueryTest( global_state = query.initial_global_state() - initial_clip = global_state.l2_norm_clip + initial_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(initial_clip, 10.0) # On the first two iterations, nothing is clipped, so the clip goes down @@ -156,7 +156,7 @@ class QuantileAdaptiveClipSumQueryTest( actual_sum, global_state = test_utils.run_query( query, [record1, record2], global_state) - actual_clip = global_state.l2_norm_clip + actual_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_sum.numpy(), (expected_sum,)) @@ -176,7 +176,7 @@ class QuantileAdaptiveClipSumQueryTest( global_state = query.initial_global_state() - initial_clip = global_state.l2_norm_clip + initial_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(initial_clip, 16.0) # For two iterations, nothing is clipped, so the clip is cut in half. @@ -192,7 +192,7 @@ class QuantileAdaptiveClipSumQueryTest( actual_sum, global_state = test_utils.run_query( query, [record1, record2], global_state) - actual_clip = global_state.l2_norm_clip + actual_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_sum.numpy(), (expected_sum,)) @@ -211,7 +211,7 @@ class QuantileAdaptiveClipSumQueryTest( global_state = query.initial_global_state() - initial_clip = global_state.l2_norm_clip + initial_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(initial_clip, 0.0) # On the first two iterations, both are clipped, so the clip goes up @@ -225,7 +225,7 @@ class QuantileAdaptiveClipSumQueryTest( actual_sum, global_state = test_utils.run_query( query, [record1, record2], global_state) - actual_clip = global_state.l2_norm_clip + actual_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_sum.numpy(), (expected_sum,)) @@ -245,7 +245,7 @@ class QuantileAdaptiveClipSumQueryTest( global_state = query.initial_global_state() - initial_clip = global_state.l2_norm_clip + initial_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(initial_clip, 0.5) # On the first two iterations, both are clipped, so the clip is doubled. @@ -261,7 +261,7 @@ class QuantileAdaptiveClipSumQueryTest( actual_sum, global_state = test_utils.run_query( query, [record1, record2], global_state) - actual_clip = global_state.l2_norm_clip + actual_clip = global_state.sum_state.l2_norm_clip self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_sum.numpy(), (expected_sum,)) @@ -295,7 +295,7 @@ class QuantileAdaptiveClipSumQueryTest( tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1)) _, global_state = test_utils.run_query(query, records, global_state) - actual_clip = global_state.l2_norm_clip + actual_clip = global_state.sum_state.l2_norm_clip if t > 40: self.assertNear(actual_clip, 5.0, 0.25) @@ -325,7 +325,7 @@ class QuantileAdaptiveClipSumQueryTest( tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1)) _, global_state = test_utils.run_query(query, records, global_state) - actual_clip = global_state.l2_norm_clip + actual_clip = global_state.sum_state.l2_norm_clip if t > 40: self.assertNear(actual_clip, 5.0, 0.5)