Update reset and pre-process functions for tree aggregation queries. Minor comment updates for adaptive clip query tests.
PiperOrigin-RevId: 396483111
This commit is contained in:
parent
0d05f2eb18
commit
b572707cfc
4 changed files with 47 additions and 9 deletions
|
@ -230,7 +230,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
|
||||||
('start_high_arithmetic', False, False),
|
('start_high_arithmetic', False, False),
|
||||||
('start_high_geometric', False, True))
|
('start_high_geometric', False, True))
|
||||||
def test_adaptation_linspace(self, start_low, geometric):
|
def test_adaptation_linspace(self, start_low, geometric):
|
||||||
# 100 records equally spaced from 0 to 10 in 0.1 increments.
|
# `num_records` records equally spaced from 0 to 10.
|
||||||
# Test that we converge to the correct median value and bounce around it.
|
# Test that we converge to the correct median value and bounce around it.
|
||||||
num_records = 21
|
num_records = 21
|
||||||
records = [
|
records = [
|
||||||
|
@ -262,9 +262,10 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
|
||||||
('start_high_arithmetic', False, False),
|
('start_high_arithmetic', False, False),
|
||||||
('start_high_geometric', False, True))
|
('start_high_geometric', False, True))
|
||||||
def test_adaptation_all_equal(self, start_low, geometric):
|
def test_adaptation_all_equal(self, start_low, geometric):
|
||||||
# 20 equal records. Test that we converge to that record and bounce around
|
# `num_records` equal records. Test that we converge to that record and
|
||||||
# it. Unlike the linspace test, the quantile-matching objective is very
|
# bounce around it. Unlike the linspace test, the quantile-matching
|
||||||
# sharp at the optimum so a decaying learning rate is necessary.
|
# objective is very sharp at the optimum so a decaying learning rate is
|
||||||
|
# necessary.
|
||||||
num_records = 20
|
num_records = 20
|
||||||
records = [tf.constant(5.0)] * num_records
|
records = [tf.constant(5.0)] * num_records
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
updating is preferred for non-negative records like vector norms that
|
updating is preferred for non-negative records like vector norms that
|
||||||
could potentially be very large or very close to zero.
|
could potentially be very large or very close to zero.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if target_quantile < 0 or target_quantile > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f'`target_quantile` must be between 0 and 1, got {target_quantile}.')
|
||||||
|
|
||||||
|
if learning_rate < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f'`learning_rate` must be non-negative, got {learning_rate}')
|
||||||
|
|
||||||
self._initial_estimate = initial_estimate
|
self._initial_estimate = initial_estimate
|
||||||
self._target_quantile = target_quantile
|
self._target_quantile = target_quantile
|
||||||
self._learning_rate = learning_rate
|
self._learning_rate = learning_rate
|
||||||
|
@ -208,7 +217,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
return no_privacy_query.NoPrivacyAverageQuery()
|
return no_privacy_query.NoPrivacyAverageQuery()
|
||||||
|
|
||||||
|
|
||||||
class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery):
|
class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
"""Iterative process to estimate target quantile of a univariate distribution.
|
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||||
|
|
||||||
Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the
|
Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the
|
||||||
|
|
|
@ -38,7 +38,7 @@ def _make_quantile_estimator_query(initial_estimate,
|
||||||
tree_aggregation=False):
|
tree_aggregation=False):
|
||||||
if expected_num_records is not None:
|
if expected_num_records is not None:
|
||||||
if tree_aggregation:
|
if tree_aggregation:
|
||||||
return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery(
|
return quantile_estimator_query.TreeQuantileEstimatorQuery(
|
||||||
initial_estimate, target_quantile, learning_rate,
|
initial_estimate, target_quantile, learning_rate,
|
||||||
below_estimate_stddev, expected_num_records, geometric_update)
|
below_estimate_stddev, expected_num_records, geometric_update)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -360,6 +360,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
return global_state.clip_value
|
return global_state.clip_value
|
||||||
|
|
||||||
|
def preprocess_record_l2_impl(self, params, record):
|
||||||
|
"""Clips the l2 norm, returning the clipped record and the l2 norm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: The parameters for the sample.
|
||||||
|
record: The record to be processed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
|
||||||
|
the structure of preprocessed tensors, and `l2_norm` is the total l2 norm
|
||||||
|
before clipping.
|
||||||
|
"""
|
||||||
|
l2_norm_clip = params
|
||||||
|
record_as_list = tf.nest.flatten(record)
|
||||||
|
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
||||||
|
return tf.nest.pack_sequence_as(record, clipped_as_list), norm
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||||
|
|
||||||
|
@ -405,7 +422,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
`get_noised_result` when the restarting condition is met.
|
`get_noised_result` when the restarting condition is met.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
noised_results: Noised cumulative sum returned by `get_noised_result`.
|
noised_results: Noised results returned by `get_noised_result`.
|
||||||
global_state: Updated global state returned by `get_noised_result`, which
|
global_state: Updated global state returned by `get_noised_result`, which
|
||||||
records noise for the conceptual cumulative sum of the current leaf
|
records noise for the conceptual cumulative sum of the current leaf
|
||||||
node, and tree state for the next conceptual cumulative sum.
|
node, and tree state for the next conceptual cumulative sum.
|
||||||
|
@ -420,6 +437,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
previous_tree_noise=self._zero_initial_noise(),
|
previous_tree_noise=self._zero_initial_noise(),
|
||||||
tree_state=new_tree_state)
|
tree_state=new_tree_state)
|
||||||
|
|
||||||
|
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
|
||||||
|
noise_generator_state = global_state.tree_state.value_generator_state
|
||||||
|
assert isinstance(self._tree_aggregator.value_generator,
|
||||||
|
tree_aggregation.GaussianNoiseGenerator)
|
||||||
|
noise_generator_state = self._tree_aggregator.value_generator.make_state(
|
||||||
|
noise_generator_state.seeds, stddev)
|
||||||
|
new_tree_state = attr.evolve(
|
||||||
|
global_state.tree_state, value_generator_state=noise_generator_state)
|
||||||
|
return attr.evolve(
|
||||||
|
global_state, clip_value=clip_norm, tree_state=new_tree_state)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_l2_gaussian_query(cls,
|
def build_l2_gaussian_query(cls,
|
||||||
clip_norm,
|
clip_norm,
|
||||||
|
@ -442,8 +470,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
aggregation algorithm based on the paper "Efficient Use of
|
aggregation algorithm based on the paper "Efficient Use of
|
||||||
Differentially Private Binary Trees".
|
Differentially Private Binary Trees".
|
||||||
"""
|
"""
|
||||||
if clip_norm <= 0:
|
if clip_norm < 0:
|
||||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
raise ValueError(f'`clip_norm` must be non-negative, got {clip_norm}.')
|
||||||
|
|
||||||
if noise_multiplier < 0:
|
if noise_multiplier < 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
Loading…
Reference in a new issue