diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py index dbe766f..fef4d88 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py @@ -273,10 +273,41 @@ class QuantileAdaptiveClipSumQueryTest( ('start_high_geometric', False, True)) def test_adaptation_linspace(self, start_low, geometric): # 100 records equally spaced from 0 to 10 in 0.1 increments. - # Test that with a decaying learning rate we converge to the correct - # median value and bounce around it. + # Test that we converge to the correct median value and bounce around it. + num_records = 21 records = [tf.constant(x) for x in np.linspace( - 0.0, 10.0, num=21, dtype=np.float32)] + 0.0, 10.0, num=num_records, dtype=np.float32)] + + query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( + initial_l2_norm_clip=(1.0 if start_low else 10.0), + noise_multiplier=0.0, + target_unclipped_quantile=0.5, + learning_rate=1.0, + clipped_count_stddev=0.0, + expected_num_records=num_records, + geometric_update=geometric) + + global_state = query.initial_global_state() + + for t in range(50): + _, global_state = test_utils.run_query(query, records, global_state) + + actual_clip = global_state.sum_state.l2_norm_clip + + if t > 40: + self.assertNear(actual_clip, 5.0, 0.25) + + @parameterized.named_parameters( + ('start_low_arithmetic', True, False), + ('start_low_geometric', True, True), + ('start_high_arithmetic', False, False), + ('start_high_geometric', False, True)) + def test_adaptation_all_equal(self, start_low, geometric): + # 20 equal records. Test that we converge to that record and bounce around + # it. Unlike the linspace test, the quantile-matching objective is very + # sharp at the optimum so a decaying learning rate is necessary. + num_records = 20 + records = [tf.constant(5.0)] * num_records learning_rate = tf.Variable(1.0) @@ -286,37 +317,7 @@ class QuantileAdaptiveClipSumQueryTest( target_unclipped_quantile=0.5, learning_rate=learning_rate, clipped_count_stddev=0.0, - expected_num_records=2.0, - geometric_update=geometric) - - global_state = query.initial_global_state() - - for t in range(50): - tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1)) - _, global_state = test_utils.run_query(query, records, global_state) - - actual_clip = global_state.sum_state.l2_norm_clip - - if t > 40: - self.assertNear(actual_clip, 5.0, 0.25) - - @parameterized.named_parameters( - ('arithmetic', False), - ('geometric', True)) - def test_adaptation_all_equal(self, geometric): - # 20 equal records. Test that with a decaying learning rate we converge to - # that record and bounce around it. - records = [tf.constant(5.0)] * 20 - - learning_rate = tf.Variable(1.0) - - query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( - initial_l2_norm_clip=1.0, - noise_multiplier=0.0, - target_unclipped_quantile=0.5, - learning_rate=learning_rate, - clipped_count_stddev=0.0, - expected_num_records=2.0, + expected_num_records=num_records, geometric_update=geometric) global_state = query.initial_global_state()