From ef83391ce67ba9aa6aa02d3d7237420147d1831f Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Thu, 19 Aug 2021 23:56:16 -0700 Subject: [PATCH] Use tree aggregation noise for quantile estimation. PiperOrigin-RevId: 391928297 --- .../dp_query/quantile_estimator_query.py | 28 +++++ .../dp_query/quantile_estimator_query_test.py | 118 +++++++++++------- 2 files changed, 98 insertions(+), 48 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py index 4358a95..e23b83d 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py @@ -24,6 +24,7 @@ from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import no_privacy_query from tensorflow_privacy.privacy.dp_query import normalized_query +from tensorflow_privacy.privacy.dp_query import tree_aggregation_query class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): @@ -209,3 +210,30 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery): del below_estimate_stddev del expected_num_records return no_privacy_query.NoPrivacyAverageQuery() + + +class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery): + """Iterative process to estimate target quantile of a univariate distribution. + + Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the + fraction below estimate with an exact denominator. This assumes that below + estimate value is used in a SGD-like update and we want to privatize the + cumsum of the below estimate. + + See "Practical and Private (Deep) Learning without Sampling or Shuffling" + (https://arxiv.org/abs/2103.00039) for tree aggregation and privacy + accounting, and "Differentially Private Learning with Adaptive Clipping" + (https://arxiv.org/abs/1905.03871) for how below estimate is used in a + SGD-like algorithm. + """ + + def _construct_below_estimate_query(self, below_estimate_stddev, + expected_num_records): + # See comments in `QuantileEstimatorQuery._construct_below_estimate_query` + # for why clip norm 0.5 is used for the query. + sum_query = tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query( + clip_norm=0.5, + noise_multiplier=2 * below_estimate_stddev, + record_specs=tf.TensorSpec([])) + return normalized_query.NormalizedQuery( + sum_query, denominator=expected_num_records) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py index 679e525..d349f56 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py @@ -29,22 +29,26 @@ from tensorflow_privacy.privacy.dp_query import test_utils tf.enable_eager_execution() -def _make_quantile_estimator_query( - initial_estimate, - target_quantile, - learning_rate, - below_estimate_stddev, - expected_num_records, - geometric_update): +def _make_quantile_estimator_query(initial_estimate, + target_quantile, + learning_rate, + below_estimate_stddev, + expected_num_records, + geometric_update, + tree_aggregation=False): if expected_num_records is not None: - return quantile_estimator_query.QuantileEstimatorQuery( - initial_estimate, - target_quantile, - learning_rate, - below_estimate_stddev, - expected_num_records, - geometric_update) + if tree_aggregation: + return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery( + initial_estimate, target_quantile, learning_rate, + below_estimate_stddev, expected_num_records, geometric_update) + else: + return quantile_estimator_query.QuantileEstimatorQuery( + initial_estimate, target_quantile, learning_rate, + below_estimate_stddev, expected_num_records, geometric_update) else: + if tree_aggregation: + raise ValueError( + 'Cannot set expected_num_records to None for tree aggregation.') return quantile_estimator_query.NoPrivacyQuantileEstimatorQuery( initial_estimate, target_quantile, @@ -54,8 +58,9 @@ def _make_quantile_estimator_query( class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): - @parameterized.named_parameters(('exact', True), ('fixed', False)) - def test_target_zero(self, exact): + @parameterized.named_parameters( + ('exact', True, False), ('fixed', False, False), ('tree', False, True)) + def test_target_zero(self, exact, tree): record1 = tf.constant(8.5) record2 = tf.constant(7.25) @@ -65,7 +70,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): learning_rate=1.0, below_estimate_stddev=0.0, expected_num_records=(None if exact else 2.0), - geometric_update=False) + geometric_update=False, + tree_aggregation=tree) global_state = query.initial_global_state() @@ -84,18 +90,20 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(actual_estimate.numpy(), expected_estimate) - @parameterized.named_parameters(('exact', True), ('fixed', False)) - def test_target_zero_geometric(self, exact): + @parameterized.named_parameters( + ('exact', True, False), ('fixed', False, False), ('tree', False, True)) + def test_target_zero_geometric(self, exact, tree): record1 = tf.constant(5.0) record2 = tf.constant(2.5) query = _make_quantile_estimator_query( initial_estimate=16.0, target_quantile=0.0, - learning_rate=np.log(2.0), # Geometric steps in powers of 2. + learning_rate=np.log(2.0), # Geometric steps in powers of 2. below_estimate_stddev=0.0, expected_num_records=(None if exact else 2.0), - geometric_update=True) + geometric_update=True, + tree_aggregation=tree) global_state = query.initial_global_state() @@ -116,8 +124,9 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(actual_estimate.numpy(), expected_estimate) - @parameterized.named_parameters(('exact', True), ('fixed', False)) - def test_target_one(self, exact): + @parameterized.named_parameters( + ('exact', True, False), ('fixed', False, False), ('tree', False, True)) + def test_target_one(self, exact, tree): record1 = tf.constant(1.5) record2 = tf.constant(2.75) @@ -127,7 +136,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): learning_rate=1.0, below_estimate_stddev=0.0, expected_num_records=(None if exact else 2.0), - geometric_update=False) + geometric_update=False, + tree_aggregation=tree) global_state = query.initial_global_state() @@ -146,18 +156,20 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(actual_estimate.numpy(), expected_estimate) - @parameterized.named_parameters(('exact', True), ('fixed', False)) - def test_target_one_geometric(self, exact): + @parameterized.named_parameters( + ('exact', True, False), ('fixed', False, False), ('tree', False, True)) + def test_target_one_geometric(self, exact, tree): record1 = tf.constant(1.5) record2 = tf.constant(3.0) query = _make_quantile_estimator_query( initial_estimate=0.5, target_quantile=1.0, - learning_rate=np.log(2.0), # Geometric steps in powers of 2. + learning_rate=np.log(2.0), # Geometric steps in powers of 2. below_estimate_stddev=0.0, expected_num_records=(None if exact else 2.0), - geometric_update=True) + geometric_update=True, + tree_aggregation=tree) global_state = query.initial_global_state() @@ -179,15 +191,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(actual_estimate.numpy(), expected_estimate) @parameterized.named_parameters( - ('start_low_geometric_exact', True, True, True), - ('start_low_arithmetic_exact', True, True, False), - ('start_high_geometric_exact', True, False, True), - ('start_high_arithmetic_exact', True, False, False), - ('start_low_geometric_noised', False, True, True), - ('start_low_arithmetic_noised', False, True, False), - ('start_high_geometric_noised', False, False, True), - ('start_high_arithmetic_noised', False, False, False)) - def test_linspace(self, exact, start_low, geometric): + ('start_low_geometric_exact', True, True, True, False), + ('start_low_arithmetic_exact', True, True, False, False), + ('start_high_geometric_exact', True, False, True, False), + ('start_high_arithmetic_exact', True, False, False, False), + ('start_low_geometric_noised', False, True, True, False), + ('start_low_arithmetic_noised', False, True, False, False), + ('start_high_geometric_noised', False, False, True, False), + ('start_high_arithmetic_noised', False, False, False, False), + ('start_low_geometric_tree', False, True, True, True), + ('start_low_arithmetic_tree', False, True, False, True), + ('start_high_geometric_tree', False, False, True, True), + ('start_high_arithmetic_tree', False, False, False, True)) + def test_linspace(self, exact, start_low, geometric, tree): # 100 records equally spaced from 0 to 10 in 0.1 increments. # Test that we converge to the correct median value and bounce around it. num_records = 21 @@ -200,7 +216,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): learning_rate=1.0, below_estimate_stddev=(0.0 if exact else 1e-2), expected_num_records=(None if exact else num_records), - geometric_update=geometric) + geometric_update=geometric, + tree_aggregation=tree) global_state = query.initial_global_state() @@ -213,15 +230,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertNear(actual_estimate, 5.0, 0.25) @parameterized.named_parameters( - ('start_low_geometric_exact', True, True, True), - ('start_low_arithmetic_exact', True, True, False), - ('start_high_geometric_exact', True, False, True), - ('start_high_arithmetic_exact', True, False, False), - ('start_low_geometric_noised', False, True, True), - ('start_low_arithmetic_noised', False, True, False), - ('start_high_geometric_noised', False, False, True), - ('start_high_arithmetic_noised', False, False, False)) - def test_all_equal(self, exact, start_low, geometric): + ('start_low_geometric_exact', True, True, True, False), + ('start_low_arithmetic_exact', True, True, False, False), + ('start_high_geometric_exact', True, False, True, False), + ('start_high_arithmetic_exact', True, False, False, False), + ('start_low_geometric_noised', False, True, True, False), + ('start_low_arithmetic_noised', False, True, False, False), + ('start_high_geometric_noised', False, False, True, False), + ('start_high_arithmetic_noised', False, False, False, False), + ('start_low_geometric_tree', False, True, True, True), + ('start_low_arithmetic_tree', False, True, False, True), + ('start_high_geometric_tree', False, False, True, True), + ('start_high_arithmetic_tree', False, False, False, True)) + def test_all_equal(self, exact, start_low, geometric, tree): # 20 equal records. Test that we converge to that record and bounce around # it. Unlike the linspace test, the quantile-matching objective is very # sharp at the optimum so a decaying learning rate is necessary. @@ -236,7 +257,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): learning_rate=learning_rate, below_estimate_stddev=(0.0 if exact else 1e-2), expected_num_records=(None if exact else num_records), - geometric_update=geometric) + geometric_update=geometric, + tree_aggregation=tree) global_state = query.initial_global_state()