Update reset and pre-process functions for tree aggregation queries. Minor comment updates for adaptive clip query tests.

PiperOrigin-RevId: 396483111
Zheng Xu 2021-09-13 17:47:50 -07:00 committed by A. Unique TensorFlower
parent 0d05f2eb18
commit b572707cfc
4 changed files with 47 additions and 9 deletions


@@ -230,7 +230,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
       ('start_high_arithmetic', False, False),
       ('start_high_geometric', False, True))
   def test_adaptation_linspace(self, start_low, geometric):
-    # 100 records equally spaced from 0 to 10 in 0.1 increments.
+    # `num_records` records equally spaced from 0 to 10 in 0.1 increments.
     # Test that we converge to the correct median value and bounce around it.
     num_records = 21
     records = [
@@ -262,9 +262,10 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
       ('start_high_arithmetic', False, False),
       ('start_high_geometric', False, True))
   def test_adaptation_all_equal(self, start_low, geometric):
-    # 20 equal records. Test that we converge to that record and bounce around
-    # it. Unlike the linspace test, the quantile-matching objective is very
-    # sharp at the optimum so a decaying learning rate is necessary.
+    # `num_records` equal records. Test that we converge to that record and
+    # bounce around it. Unlike the linspace test, the quantile-matching
+    # objective is very sharp at the optimum so a decaying learning rate is
+    # necessary.
     num_records = 20
     records = [tf.constant(5.0)] * num_records
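
For context, here is an illustrative sketch (not the tests' literal code) of the record sets these comments describe; in both cases the value the quantile estimator should converge to is 5.0.

import numpy as np
import tensorflow as tf

# Linspace test: `num_records` values equally spaced over [0, 10]; median is 5.0.
num_records = 21
linspace_records = [
    tf.constant(x)
    for x in np.linspace(0.0, 10.0, num=num_records, dtype=np.float32)
]

# All-equal test: every record is the constant 5.0.
equal_records = [tf.constant(5.0)] * 20

print(float(np.median([r.numpy() for r in linspace_records])))  # 5.0
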


@@ -74,6 +74,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
         updating is preferred for non-negative records like vector norms that
         could potentially be very large or very close to zero.
     """
+    if target_quantile < 0 or target_quantile > 1:
+      raise ValueError(
+          f'`target_quantile` must be between 0 and 1, got {target_quantile}.')
+
+    if learning_rate < 0:
+      raise ValueError(
+          f'`learning_rate` must be non-negative, got {learning_rate}')
+
     self._initial_estimate = initial_estimate
     self._target_quantile = target_quantile
     self._learning_rate = learning_rate
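
The effect of the new checks is that invalid hyperparameters now fail fast at construction time. A minimal sketch of that behavior, assuming the module path and the parameter names used by the test helper below (hypothetical values):

from tensorflow_privacy.privacy.dp_query import quantile_estimator_query

try:
  quantile_estimator_query.QuantileEstimatorQuery(
      initial_estimate=1.0,
      target_quantile=1.5,  # outside [0, 1], so construction is rejected
      learning_rate=0.2,
      below_estimate_stddev=0.0,
      expected_num_records=100,
      geometric_update=True)
except ValueError as e:
  print(e)  # `target_quantile` must be between 0 and 1, got 1.5.
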
@@ -208,7 +217,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
     return no_privacy_query.NoPrivacyAverageQuery()


-class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery):
+class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
   """Iterative process to estimate target quantile of a univariate distribution.

   Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the


@@ -38,7 +38,7 @@ def _make_quantile_estimator_query(initial_estimate,
                                    tree_aggregation=False):
   if expected_num_records is not None:
     if tree_aggregation:
-      return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery(
+      return quantile_estimator_query.TreeQuantileEstimatorQuery(
           initial_estimate, target_quantile, learning_rate,
           below_estimate_stddev, expected_num_records, geometric_update)
     else:


@@ -360,6 +360,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
     """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
     return global_state.clip_value

+  def preprocess_record_l2_impl(self, params, record):
+    """Clips the l2 norm, returning the clipped record and the l2 norm.
+
+    Args:
+      params: The parameters for the sample.
+      record: The record to be processed.
+
+    Returns:
+      A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
+        the structure of preprocessed tensors, and l2_norm is the total l2 norm
+        before clipping.
+    """
+    l2_norm_clip = params
+    record_as_list = tf.nest.flatten(record)
+    clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
+    return tf.nest.pack_sequence_as(record, clipped_as_list), norm
+
   def preprocess_record(self, params, record):
     """Implements `tensorflow_privacy.DPQuery.preprocess_record`.
@@ -405,7 +422,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
     `get_noised_result` when the restarting condition is met.

     Args:
-      noised_results: Noised cumulative sum returned by `get_noised_result`.
+      noised_results: Noised results returned by `get_noised_result`.
       global_state: Updated global state returned by `get_noised_result`, which
         records noise for the conceptual cumulative sum of the current leaf
         node, and tree state for the next conceptual cumulative sum.
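
The `preprocess_record_l2_impl` helper added above is a thin wrapper around `tf.clip_by_global_norm`. Here is a self-contained sketch of the same flatten / clip / repack pattern on a toy nested record (illustrative values, not the query's own code path):

import tensorflow as tf

record = {'w': tf.constant([3.0, 4.0]), 'b': tf.constant([12.0])}
l2_norm_clip = 1.0  # plays the role of `params`

record_as_list = tf.nest.flatten(record)
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
clipped_record = tf.nest.pack_sequence_as(record, clipped_as_list)

print(norm.numpy())                 # 13.0: the global l2 norm before clipping
print(clipped_record['w'].numpy())  # [3/13, 4/13]: scaled so the global norm is 1.0
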
@@ -420,6 +437,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
         previous_tree_noise=self._zero_initial_noise(),
         tree_state=new_tree_state)

+  def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
+    noise_generator_state = global_state.tree_state.value_generator_state
+    assert isinstance(self._tree_aggregator.value_generator,
+                      tree_aggregation.GaussianNoiseGenerator)
+    noise_generator_state = self._tree_aggregator.value_generator.make_state(
+        noise_generator_state.seeds, stddev)
+    new_tree_state = attr.evolve(
+        global_state.tree_state, value_generator_state=noise_generator_state)
+    return attr.evolve(
+        global_state, clip_value=clip_norm, tree_state=new_tree_state)
+
   @classmethod
   def build_l2_gaussian_query(cls,
                               clip_norm,
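
The new `reset_l2_clip_gaussian_noise` method relies on `attr.evolve` to rebuild immutable state objects with a few fields replaced. A minimal sketch of that pattern using hypothetical stand-in classes (not the library's actual state types):

import attr


@attr.s(frozen=True)
class ToyTreeState(object):
  value_generator_state = attr.ib()


@attr.s(frozen=True)
class ToyGlobalState(object):
  clip_value = attr.ib()
  tree_state = attr.ib()


state = ToyGlobalState(clip_value=1.0,
                       tree_state=ToyTreeState(value_generator_state='gen0'))
# `attr.evolve` copies a frozen instance, replacing only the named fields.
new_tree_state = attr.evolve(state.tree_state, value_generator_state='gen1')
new_state = attr.evolve(state, clip_value=0.5, tree_state=new_tree_state)
print(new_state.clip_value, new_state.tree_state.value_generator_state)  # 0.5 gen1
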
@@ -442,8 +470,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
         aggregation algorithm based on the paper "Efficient Use of
         Differentially Private Binary Trees".
     """
-    if clip_norm <= 0:
-      raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
+    if clip_norm < 0:
+      raise ValueError(f'`clip_norm` must be non-negative, got {clip_norm}.')

     if noise_multiplier < 0:
       raise ValueError(