Use tree aggregation noise for quantile estimation.
PiperOrigin-RevId: 391928297
This commit is contained in:
parent
0600fa26a2
commit
ef83391ce6
2 changed files with 98 additions and 48 deletions
|
@ -24,6 +24,7 @@ from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import no_privacy_query
|
from tensorflow_privacy.privacy.dp_query import no_privacy_query
|
||||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||||
|
|
||||||
|
|
||||||
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
@ -209,3 +210,30 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
del below_estimate_stddev
|
del below_estimate_stddev
|
||||||
del expected_num_records
|
del expected_num_records
|
||||||
return no_privacy_query.NoPrivacyAverageQuery()
|
return no_privacy_query.NoPrivacyAverageQuery()
|
||||||
|
|
||||||
|
|
||||||
|
class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
|
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||||
|
|
||||||
|
Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the
|
||||||
|
fraction below estimate with an exact denominator. This assumes that below
|
||||||
|
estimate value is used in a SGD-like update and we want to privatize the
|
||||||
|
cumsum of the below estimate.
|
||||||
|
|
||||||
|
See "Practical and Private (Deep) Learning without Sampling or Shuffling"
|
||||||
|
(https://arxiv.org/abs/2103.00039) for tree aggregation and privacy
|
||||||
|
accounting, and "Differentially Private Learning with Adaptive Clipping"
|
||||||
|
(https://arxiv.org/abs/1905.03871) for how below estimate is used in a
|
||||||
|
SGD-like algorithm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _construct_below_estimate_query(self, below_estimate_stddev,
|
||||||
|
expected_num_records):
|
||||||
|
# See comments in `QuantileEstimatorQuery._construct_below_estimate_query`
|
||||||
|
# for why clip norm 0.5 is used for the query.
|
||||||
|
sum_query = tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
|
||||||
|
clip_norm=0.5,
|
||||||
|
noise_multiplier=2 * below_estimate_stddev,
|
||||||
|
record_specs=tf.TensorSpec([]))
|
||||||
|
return normalized_query.NormalizedQuery(
|
||||||
|
sum_query, denominator=expected_num_records)
|
||||||
|
|
|
@ -29,22 +29,26 @@ from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
|
|
||||||
|
|
||||||
def _make_quantile_estimator_query(
|
def _make_quantile_estimator_query(initial_estimate,
|
||||||
initial_estimate,
|
|
||||||
target_quantile,
|
target_quantile,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
below_estimate_stddev,
|
below_estimate_stddev,
|
||||||
expected_num_records,
|
expected_num_records,
|
||||||
geometric_update):
|
geometric_update,
|
||||||
|
tree_aggregation=False):
|
||||||
if expected_num_records is not None:
|
if expected_num_records is not None:
|
||||||
return quantile_estimator_query.QuantileEstimatorQuery(
|
if tree_aggregation:
|
||||||
initial_estimate,
|
return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery(
|
||||||
target_quantile,
|
initial_estimate, target_quantile, learning_rate,
|
||||||
learning_rate,
|
below_estimate_stddev, expected_num_records, geometric_update)
|
||||||
below_estimate_stddev,
|
|
||||||
expected_num_records,
|
|
||||||
geometric_update)
|
|
||||||
else:
|
else:
|
||||||
|
return quantile_estimator_query.QuantileEstimatorQuery(
|
||||||
|
initial_estimate, target_quantile, learning_rate,
|
||||||
|
below_estimate_stddev, expected_num_records, geometric_update)
|
||||||
|
else:
|
||||||
|
if tree_aggregation:
|
||||||
|
raise ValueError(
|
||||||
|
'Cannot set expected_num_records to None for tree aggregation.')
|
||||||
return quantile_estimator_query.NoPrivacyQuantileEstimatorQuery(
|
return quantile_estimator_query.NoPrivacyQuantileEstimatorQuery(
|
||||||
initial_estimate,
|
initial_estimate,
|
||||||
target_quantile,
|
target_quantile,
|
||||||
|
@ -54,8 +58,9 @@ def _make_quantile_estimator_query(
|
||||||
|
|
||||||
class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
@parameterized.named_parameters(
|
||||||
def test_target_zero(self, exact):
|
('exact', True, False), ('fixed', False, False), ('tree', False, True))
|
||||||
|
def test_target_zero(self, exact, tree):
|
||||||
record1 = tf.constant(8.5)
|
record1 = tf.constant(8.5)
|
||||||
record2 = tf.constant(7.25)
|
record2 = tf.constant(7.25)
|
||||||
|
|
||||||
|
@ -65,7 +70,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=(None if exact else 2.0),
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=False)
|
geometric_update=False,
|
||||||
|
tree_aggregation=tree)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
@ -84,8 +90,9 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
@parameterized.named_parameters(
|
||||||
def test_target_zero_geometric(self, exact):
|
('exact', True, False), ('fixed', False, False), ('tree', False, True))
|
||||||
|
def test_target_zero_geometric(self, exact, tree):
|
||||||
record1 = tf.constant(5.0)
|
record1 = tf.constant(5.0)
|
||||||
record2 = tf.constant(2.5)
|
record2 = tf.constant(2.5)
|
||||||
|
|
||||||
|
@ -95,7 +102,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=(None if exact else 2.0),
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=True)
|
geometric_update=True,
|
||||||
|
tree_aggregation=tree)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
@ -116,8 +124,9 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
@parameterized.named_parameters(
|
||||||
def test_target_one(self, exact):
|
('exact', True, False), ('fixed', False, False), ('tree', False, True))
|
||||||
|
def test_target_one(self, exact, tree):
|
||||||
record1 = tf.constant(1.5)
|
record1 = tf.constant(1.5)
|
||||||
record2 = tf.constant(2.75)
|
record2 = tf.constant(2.75)
|
||||||
|
|
||||||
|
@ -127,7 +136,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=(None if exact else 2.0),
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=False)
|
geometric_update=False,
|
||||||
|
tree_aggregation=tree)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
@ -146,8 +156,9 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
@parameterized.named_parameters(
|
||||||
def test_target_one_geometric(self, exact):
|
('exact', True, False), ('fixed', False, False), ('tree', False, True))
|
||||||
|
def test_target_one_geometric(self, exact, tree):
|
||||||
record1 = tf.constant(1.5)
|
record1 = tf.constant(1.5)
|
||||||
record2 = tf.constant(3.0)
|
record2 = tf.constant(3.0)
|
||||||
|
|
||||||
|
@ -157,7 +168,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=(None if exact else 2.0),
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=True)
|
geometric_update=True,
|
||||||
|
tree_aggregation=tree)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
@ -179,15 +191,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('start_low_geometric_exact', True, True, True),
|
('start_low_geometric_exact', True, True, True, False),
|
||||||
('start_low_arithmetic_exact', True, True, False),
|
('start_low_arithmetic_exact', True, True, False, False),
|
||||||
('start_high_geometric_exact', True, False, True),
|
('start_high_geometric_exact', True, False, True, False),
|
||||||
('start_high_arithmetic_exact', True, False, False),
|
('start_high_arithmetic_exact', True, False, False, False),
|
||||||
('start_low_geometric_noised', False, True, True),
|
('start_low_geometric_noised', False, True, True, False),
|
||||||
('start_low_arithmetic_noised', False, True, False),
|
('start_low_arithmetic_noised', False, True, False, False),
|
||||||
('start_high_geometric_noised', False, False, True),
|
('start_high_geometric_noised', False, False, True, False),
|
||||||
('start_high_arithmetic_noised', False, False, False))
|
('start_high_arithmetic_noised', False, False, False, False),
|
||||||
def test_linspace(self, exact, start_low, geometric):
|
('start_low_geometric_tree', False, True, True, True),
|
||||||
|
('start_low_arithmetic_tree', False, True, False, True),
|
||||||
|
('start_high_geometric_tree', False, False, True, True),
|
||||||
|
('start_high_arithmetic_tree', False, False, False, True))
|
||||||
|
def test_linspace(self, exact, start_low, geometric, tree):
|
||||||
# 100 records equally spaced from 0 to 10 in 0.1 increments.
|
# 100 records equally spaced from 0 to 10 in 0.1 increments.
|
||||||
# Test that we converge to the correct median value and bounce around it.
|
# Test that we converge to the correct median value and bounce around it.
|
||||||
num_records = 21
|
num_records = 21
|
||||||
|
@ -200,7 +216,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
below_estimate_stddev=(0.0 if exact else 1e-2),
|
below_estimate_stddev=(0.0 if exact else 1e-2),
|
||||||
expected_num_records=(None if exact else num_records),
|
expected_num_records=(None if exact else num_records),
|
||||||
geometric_update=geometric)
|
geometric_update=geometric,
|
||||||
|
tree_aggregation=tree)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
@ -213,15 +230,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertNear(actual_estimate, 5.0, 0.25)
|
self.assertNear(actual_estimate, 5.0, 0.25)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('start_low_geometric_exact', True, True, True),
|
('start_low_geometric_exact', True, True, True, False),
|
||||||
('start_low_arithmetic_exact', True, True, False),
|
('start_low_arithmetic_exact', True, True, False, False),
|
||||||
('start_high_geometric_exact', True, False, True),
|
('start_high_geometric_exact', True, False, True, False),
|
||||||
('start_high_arithmetic_exact', True, False, False),
|
('start_high_arithmetic_exact', True, False, False, False),
|
||||||
('start_low_geometric_noised', False, True, True),
|
('start_low_geometric_noised', False, True, True, False),
|
||||||
('start_low_arithmetic_noised', False, True, False),
|
('start_low_arithmetic_noised', False, True, False, False),
|
||||||
('start_high_geometric_noised', False, False, True),
|
('start_high_geometric_noised', False, False, True, False),
|
||||||
('start_high_arithmetic_noised', False, False, False))
|
('start_high_arithmetic_noised', False, False, False, False),
|
||||||
def test_all_equal(self, exact, start_low, geometric):
|
('start_low_geometric_tree', False, True, True, True),
|
||||||
|
('start_low_arithmetic_tree', False, True, False, True),
|
||||||
|
('start_high_geometric_tree', False, False, True, True),
|
||||||
|
('start_high_arithmetic_tree', False, False, False, True))
|
||||||
|
def test_all_equal(self, exact, start_low, geometric, tree):
|
||||||
# 20 equal records. Test that we converge to that record and bounce around
|
# 20 equal records. Test that we converge to that record and bounce around
|
||||||
# it. Unlike the linspace test, the quantile-matching objective is very
|
# it. Unlike the linspace test, the quantile-matching objective is very
|
||||||
# sharp at the optimum so a decaying learning rate is necessary.
|
# sharp at the optimum so a decaying learning rate is necessary.
|
||||||
|
@ -236,7 +257,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
below_estimate_stddev=(0.0 if exact else 1e-2),
|
below_estimate_stddev=(0.0 if exact else 1e-2),
|
||||||
expected_num_records=(None if exact else num_records),
|
expected_num_records=(None if exact else num_records),
|
||||||
geometric_update=geometric)
|
geometric_update=geometric,
|
||||||
|
tree_aggregation=tree)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue