Use tree aggregation noise for quantile estimation.

PiperOrigin-RevId: 391928297
This commit is contained in:
Zheng Xu 2021-08-19 23:56:16 -07:00 committed by A. Unique TensorFlower
parent 0600fa26a2
commit ef83391ce6
2 changed files with 98 additions and 48 deletions

View file

@ -24,6 +24,7 @@ from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import no_privacy_query
from tensorflow_privacy.privacy.dp_query import normalized_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
@ -209,3 +210,30 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
del below_estimate_stddev
del expected_num_records
return no_privacy_query.NoPrivacyAverageQuery()
class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery):
"""Iterative process to estimate target quantile of a univariate distribution.
Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the
fraction below estimate with an exact denominator. This assumes that below
estimate value is used in a SGD-like update and we want to privatize the
cumsum of the below estimate.
See "Practical and Private (Deep) Learning without Sampling or Shuffling"
(https://arxiv.org/abs/2103.00039) for tree aggregation and privacy
accounting, and "Differentially Private Learning with Adaptive Clipping"
(https://arxiv.org/abs/1905.03871) for how below estimate is used in a
SGD-like algorithm.
"""
def _construct_below_estimate_query(self, below_estimate_stddev,
expected_num_records):
# See comments in `QuantileEstimatorQuery._construct_below_estimate_query`
# for why clip norm 0.5 is used for the query.
sum_query = tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
clip_norm=0.5,
noise_multiplier=2 * below_estimate_stddev,
record_specs=tf.TensorSpec([]))
return normalized_query.NormalizedQuery(
sum_query, denominator=expected_num_records)

View file

@ -29,22 +29,26 @@ from tensorflow_privacy.privacy.dp_query import test_utils
tf.enable_eager_execution()
def _make_quantile_estimator_query(
initial_estimate,
target_quantile,
learning_rate,
below_estimate_stddev,
expected_num_records,
geometric_update):
def _make_quantile_estimator_query(initial_estimate,
target_quantile,
learning_rate,
below_estimate_stddev,
expected_num_records,
geometric_update,
tree_aggregation=False):
if expected_num_records is not None:
return quantile_estimator_query.QuantileEstimatorQuery(
initial_estimate,
target_quantile,
learning_rate,
below_estimate_stddev,
expected_num_records,
geometric_update)
if tree_aggregation:
return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery(
initial_estimate, target_quantile, learning_rate,
below_estimate_stddev, expected_num_records, geometric_update)
else:
return quantile_estimator_query.QuantileEstimatorQuery(
initial_estimate, target_quantile, learning_rate,
below_estimate_stddev, expected_num_records, geometric_update)
else:
if tree_aggregation:
raise ValueError(
'Cannot set expected_num_records to None for tree aggregation.')
return quantile_estimator_query.NoPrivacyQuantileEstimatorQuery(
initial_estimate,
target_quantile,
@ -54,8 +58,9 @@ def _make_quantile_estimator_query(
class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_zero(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_zero(self, exact, tree):
record1 = tf.constant(8.5)
record2 = tf.constant(7.25)
@ -65,7 +70,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=1.0,
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=False)
geometric_update=False,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -84,18 +90,20 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_zero_geometric(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_zero_geometric(self, exact, tree):
record1 = tf.constant(5.0)
record2 = tf.constant(2.5)
query = _make_quantile_estimator_query(
initial_estimate=16.0,
target_quantile=0.0,
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=True)
geometric_update=True,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -116,8 +124,9 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_one(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_one(self, exact, tree):
record1 = tf.constant(1.5)
record2 = tf.constant(2.75)
@ -127,7 +136,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=1.0,
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=False)
geometric_update=False,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -146,18 +156,20 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_one_geometric(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_one_geometric(self, exact, tree):
record1 = tf.constant(1.5)
record2 = tf.constant(3.0)
query = _make_quantile_estimator_query(
initial_estimate=0.5,
target_quantile=1.0,
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=True)
geometric_update=True,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -179,15 +191,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(
('start_low_geometric_exact', True, True, True),
('start_low_arithmetic_exact', True, True, False),
('start_high_geometric_exact', True, False, True),
('start_high_arithmetic_exact', True, False, False),
('start_low_geometric_noised', False, True, True),
('start_low_arithmetic_noised', False, True, False),
('start_high_geometric_noised', False, False, True),
('start_high_arithmetic_noised', False, False, False))
def test_linspace(self, exact, start_low, geometric):
('start_low_geometric_exact', True, True, True, False),
('start_low_arithmetic_exact', True, True, False, False),
('start_high_geometric_exact', True, False, True, False),
('start_high_arithmetic_exact', True, False, False, False),
('start_low_geometric_noised', False, True, True, False),
('start_low_arithmetic_noised', False, True, False, False),
('start_high_geometric_noised', False, False, True, False),
('start_high_arithmetic_noised', False, False, False, False),
('start_low_geometric_tree', False, True, True, True),
('start_low_arithmetic_tree', False, True, False, True),
('start_high_geometric_tree', False, False, True, True),
('start_high_arithmetic_tree', False, False, False, True))
def test_linspace(self, exact, start_low, geometric, tree):
# 100 records equally spaced from 0 to 10 in 0.1 increments.
# Test that we converge to the correct median value and bounce around it.
num_records = 21
@ -200,7 +216,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=1.0,
below_estimate_stddev=(0.0 if exact else 1e-2),
expected_num_records=(None if exact else num_records),
geometric_update=geometric)
geometric_update=geometric,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -213,15 +230,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(actual_estimate, 5.0, 0.25)
@parameterized.named_parameters(
('start_low_geometric_exact', True, True, True),
('start_low_arithmetic_exact', True, True, False),
('start_high_geometric_exact', True, False, True),
('start_high_arithmetic_exact', True, False, False),
('start_low_geometric_noised', False, True, True),
('start_low_arithmetic_noised', False, True, False),
('start_high_geometric_noised', False, False, True),
('start_high_arithmetic_noised', False, False, False))
def test_all_equal(self, exact, start_low, geometric):
('start_low_geometric_exact', True, True, True, False),
('start_low_arithmetic_exact', True, True, False, False),
('start_high_geometric_exact', True, False, True, False),
('start_high_arithmetic_exact', True, False, False, False),
('start_low_geometric_noised', False, True, True, False),
('start_low_arithmetic_noised', False, True, False, False),
('start_high_geometric_noised', False, False, True, False),
('start_high_arithmetic_noised', False, False, False, False),
('start_low_geometric_tree', False, True, True, True),
('start_low_arithmetic_tree', False, True, False, True),
('start_high_geometric_tree', False, False, True, True),
('start_high_arithmetic_tree', False, False, False, True))
def test_all_equal(self, exact, start_low, geometric, tree):
# 20 equal records. Test that we converge to that record and bounce around
# it. Unlike the linspace test, the quantile-matching objective is very
# sharp at the optimum so a decaying learning rate is necessary.
@ -236,7 +257,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=learning_rate,
below_estimate_stddev=(0.0 if exact else 1e-2),
expected_num_records=(None if exact else num_records),
geometric_update=geometric)
geometric_update=geometric,
tree_aggregation=tree)
global_state = query.initial_global_state()