Update reset and pre-process functions for tree aggregation queries. Minor comments update for adaptive clip query tests.

PiperOrigin-RevId: 396483111
This commit is contained in:
Zheng Xu 2021-09-13 17:47:50 -07:00 committed by A. Unique TensorFlower
parent 0d05f2eb18
commit b572707cfc
4 changed files with 47 additions and 9 deletions

View file

@ -230,7 +230,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
('start_high_arithmetic', False, False),
('start_high_geometric', False, True))
def test_adaptation_linspace(self, start_low, geometric):
# 100 records equally spaced from 0 to 10 in 0.1 increments.
# `num_records` records equally spaced from 0 to 10 in 0.1 increments.
# Test that we converge to the correct median value and bounce around it.
num_records = 21
records = [
@ -262,9 +262,10 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
('start_high_arithmetic', False, False),
('start_high_geometric', False, True))
def test_adaptation_all_equal(self, start_low, geometric):
# 20 equal records. Test that we converge to that record and bounce around
# it. Unlike the linspace test, the quantile-matching objective is very
# sharp at the optimum so a decaying learning rate is necessary.
# `num_records` equal records. Test that we converge to that record and
# bounce around it. Unlike the linspace test, the quantile-matching
# objective is very sharp at the optimum so a decaying learning rate is
# necessary.
num_records = 20
records = [tf.constant(5.0)] * num_records

View file

@ -74,6 +74,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
updating is preferred for non-negative records like vector norms that
could potentially be very large or very close to zero.
"""
if target_quantile < 0 or target_quantile > 1:
raise ValueError(
f'`target_quantile` must be between 0 and 1, got {target_quantile}.')
if learning_rate < 0:
raise ValueError(
f'`learning_rate` must be non-negative, got {learning_rate}')
self._initial_estimate = initial_estimate
self._target_quantile = target_quantile
self._learning_rate = learning_rate
@ -208,7 +217,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
return no_privacy_query.NoPrivacyAverageQuery()
class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery):
class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
"""Iterative process to estimate target quantile of a univariate distribution.
Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the

View file

@ -38,7 +38,7 @@ def _make_quantile_estimator_query(initial_estimate,
tree_aggregation=False):
if expected_num_records is not None:
if tree_aggregation:
return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery(
return quantile_estimator_query.TreeQuantileEstimatorQuery(
initial_estimate, target_quantile, learning_rate,
below_estimate_stddev, expected_num_records, geometric_update)
else:

View file

@ -360,6 +360,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return global_state.clip_value
def preprocess_record_l2_impl(self, params, record):
"""Clips the l2 norm, returning the clipped record and the l2 norm.
Args:
params: The parameters for the sample.
record: The record to be processed.
Returns:
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
the structure of preprocessed tensors, and l2_norm is the total l2 norm
before clipping.
"""
l2_norm_clip = params
record_as_list = tf.nest.flatten(record)
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
return tf.nest.pack_sequence_as(record, clipped_as_list), norm
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
@ -405,7 +422,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
`get_noised_result` when the restarting condition is met.
Args:
noised_results: Noised cumulative sum returned by `get_noised_result`.
noised_results: Noised results returned by `get_noised_result`.
global_state: Updated global state returned by `get_noised_result`, which
records noise for the conceptual cumulative sum of the current leaf
node, and tree state for the next conceptual cumulative sum.
@ -420,6 +437,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
previous_tree_noise=self._zero_initial_noise(),
tree_state=new_tree_state)
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
noise_generator_state = global_state.tree_state.value_generator_state
assert isinstance(self._tree_aggregator.value_generator,
tree_aggregation.GaussianNoiseGenerator)
noise_generator_state = self._tree_aggregator.value_generator.make_state(
noise_generator_state.seeds, stddev)
new_tree_state = attr.evolve(
global_state.tree_state, value_generator_state=noise_generator_state)
return attr.evolve(
global_state, clip_value=clip_norm, tree_state=new_tree_state)
@classmethod
def build_l2_gaussian_query(cls,
clip_norm,
@ -442,8 +470,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
"""
if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
if clip_norm < 0:
raise ValueError(f'`clip_norm` must be non-negative, got {clip_norm}.')
if noise_multiplier < 0:
raise ValueError(