diff --git a/tensorflow_privacy/privacy/dp_query/BUILD b/tensorflow_privacy/privacy/dp_query/BUILD index 32815fb..76ce76b 100644 --- a/tensorflow_privacy/privacy/dp_query/BUILD +++ b/tensorflow_privacy/privacy/dp_query/BUILD @@ -127,6 +127,7 @@ py_test( deps = [ ":quantile_adaptive_clip_sum_query", ":test_utils", + "//third_party/py/absl/testing:parameterized", "//third_party/py/numpy", "//third_party/py/tensorflow", "//third_party/py/tensorflow_privacy/privacy/analysis:privacy_ledger", diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py index e6f99d8..af51f56 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py @@ -66,7 +66,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): target_unclipped_quantile, learning_rate, clipped_count_stddev, - expected_num_records): + expected_num_records, + geometric_update=False): """Initializes the QuantileAdaptiveClipSumQuery. Args: @@ -84,6 +85,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): should be about 0.5 for reasonable privacy. expected_num_records: The expected number of records per round, used to estimate the clipped count quantile. + geometric_update: If True, use geometric updating of clip. """ self._initial_l2_norm_clip = initial_l2_norm_clip self._noise_multiplier = noise_multiplier @@ -107,6 +109,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): sum_stddev=clipped_count_stddev, denominator=expected_num_records) + self._geometric_update = geometric_update + def set_ledger(self, ledger): """See base class.""" self._sum_query.set_ledger(ledger) @@ -214,8 +218,12 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): # the true quantile matches the target. loss_grad = unclipped_quantile - global_state.target_unclipped_quantile - new_l2_norm_clip = gs.l2_norm_clip - global_state.learning_rate * loss_grad - new_l2_norm_clip = tf.maximum(0.0, new_l2_norm_clip) + update = global_state.learning_rate * loss_grad + + if self._geometric_update: + new_l2_norm_clip = gs.l2_norm_clip * tf.math.exp(-update) + else: + new_l2_norm_clip = tf.math.maximum(0.0, gs.l2_norm_clip - update) new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier new_sum_query_global_state = self._sum_query.make_global_state( diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py index 731ef1b..8bdbc94 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + import numpy as np import tensorflow as tf @@ -28,7 +30,8 @@ from tensorflow_privacy.privacy.dp_query import test_utils tf.compat.v1.enable_eager_execution() -class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase): +class QuantileAdaptiveClipSumQueryTest( + tf.test.TestCase, parameterized.TestCase): def test_sum_no_clip_no_noise(self): record1 = tf.constant([2.0, 0.0]) @@ -158,6 +161,42 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase): self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_sum.numpy(), (expected_sum,)) + def test_adaptation_target_zero_geometric(self): + record1 = tf.constant([5.0]) + record2 = tf.constant([-2.5]) + + query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( + initial_l2_norm_clip=16.0, + noise_multiplier=0.0, + target_unclipped_quantile=0.0, + learning_rate=np.log(2.0), # Geometric steps in powers of 2. + clipped_count_stddev=0.0, + expected_num_records=2.0, + geometric_update=True) + + global_state = query.initial_global_state() + + initial_clip = global_state.l2_norm_clip + self.assertAllClose(initial_clip, 16.0) + + # For two iterations, nothing is clipped, so the clip is cut in half. + # Then one record is clipped, so the clip goes down by only sqrt(2.0) to + # 4 / sqrt(2.0). Still only one record is clipped, so it reduces to 2.0. + # Now both records are clipped, and the clip norm stays there (at 2.0). + + four_div_root_two = 4 / np.sqrt(2.0) # approx 2.828 + + expected_sums = [2.5, 2.5, 1.5, four_div_root_two - 2.5, 0.0] + expected_clips = [8.0, 4.0, four_div_root_two, 2.0, 2.0] + for expected_sum, expected_clip in zip(expected_sums, expected_clips): + actual_sum, global_state = test_utils.run_query( + query, [record1, record2], global_state) + + actual_clip = global_state.l2_norm_clip + + self.assertAllClose(actual_clip.numpy(), expected_clip) + self.assertAllClose(actual_sum.numpy(), (expected_sum,)) + def test_adaptation_target_one(self): record1 = tf.constant([-1.5]) record2 = tf.constant([2.75]) @@ -191,22 +230,64 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase): self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_sum.numpy(), (expected_sum,)) - def test_adaptation_linspace(self): + def test_adaptation_target_one_geometric(self): + record1 = tf.constant([-1.5]) + record2 = tf.constant([3.0]) + + query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( + initial_l2_norm_clip=0.5, + noise_multiplier=0.0, + target_unclipped_quantile=1.0, + learning_rate=np.log(2.0), # Geometric steps in powers of 2. + clipped_count_stddev=0.0, + expected_num_records=2.0, + geometric_update=True) + + global_state = query.initial_global_state() + + initial_clip = global_state.l2_norm_clip + self.assertAllClose(initial_clip, 0.5) + + # On the first two iterations, both are clipped, so the clip is doubled. + # When the clip reaches 2.0, only one record is clipped, so the clip is + # multiplied by sqrt(2.0). Still only one is clipped so it increases to 4.0. + # Now both records are clipped, and the clip norm stays there (at 4.0). + + two_times_root_two = 2 * np.sqrt(2.0) # approx 2.828 + + expected_sums = [0.0, 0.0, 0.5, two_times_root_two - 1.5, 1.5] + expected_clips = [1.0, 2.0, two_times_root_two, 4.0, 4.0] + for expected_sum, expected_clip in zip(expected_sums, expected_clips): + actual_sum, global_state = test_utils.run_query( + query, [record1, record2], global_state) + + actual_clip = global_state.l2_norm_clip + + self.assertAllClose(actual_clip.numpy(), expected_clip) + self.assertAllClose(actual_sum.numpy(), (expected_sum,)) + + @parameterized.named_parameters( + ('start_low_arithmetic', True, False), + ('start_low_geometric', True, True), + ('start_high_arithmetic', False, False), + ('start_high_geometric', False, True)) + def test_adaptation_linspace(self, start_low, geometric): # 100 records equally spaced from 0 to 10 in 0.1 increments. # Test that with a decaying learning rate we converge to the correct - # median with error at most 0.1. + # median value and bounce around it. records = [tf.constant(x) for x in np.linspace( 0.0, 10.0, num=21, dtype=np.float32)] learning_rate = tf.Variable(1.0) query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( - initial_l2_norm_clip=0.0, + initial_l2_norm_clip=(1.0 if start_low else 10.0), noise_multiplier=0.0, target_unclipped_quantile=0.5, learning_rate=learning_rate, clipped_count_stddev=0.0, - expected_num_records=2.0) + expected_num_records=2.0, + geometric_update=geometric) global_state = query.initial_global_state() @@ -219,20 +300,24 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase): if t > 40: self.assertNear(actual_clip, 5.0, 0.25) - def test_adaptation_all_equal(self): - # 100 equal records. Test that with a decaying learning rate we converge to + @parameterized.named_parameters( + ('arithmetic', False), + ('geometric', True)) + def test_adaptation_all_equal(self, geometric): + # 20 equal records. Test that with a decaying learning rate we converge to # that record and bounce around it. records = [tf.constant(5.0)] * 20 learning_rate = tf.Variable(1.0) query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( - initial_l2_norm_clip=0.0, + initial_l2_norm_clip=1.0, noise_multiplier=0.0, target_unclipped_quantile=0.5, learning_rate=learning_rate, clipped_count_stddev=0.0, - expected_num_records=2.0) + expected_num_records=2.0, + geometric_update=geometric) global_state = query.initial_global_state() @@ -243,7 +328,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase): actual_clip = global_state.l2_norm_clip if t > 40: - self.assertNear(actual_clip, 5.0, 0.25) + self.assertNear(actual_clip, 5.0, 0.5) def test_ledger(self): record1 = tf.constant([8.5])