forked from 626_privacy/tensorflow_privacy
Update reset and pre-process functions for tree aggregation queries. Minor comments update for adaptive clip query tests.
PiperOrigin-RevId: 396483111
This commit is contained in:
parent
0d05f2eb18
commit
b572707cfc
4 changed files with 47 additions and 9 deletions
|
@ -230,7 +230,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
|
|||
('start_high_arithmetic', False, False),
|
||||
('start_high_geometric', False, True))
|
||||
def test_adaptation_linspace(self, start_low, geometric):
|
||||
# 100 records equally spaced from 0 to 10 in 0.1 increments.
|
||||
# `num_records` records equally spaced from 0 to 10 in 0.1 increments.
|
||||
# Test that we converge to the correct median value and bounce around it.
|
||||
num_records = 21
|
||||
records = [
|
||||
|
@ -262,9 +262,10 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
|
|||
('start_high_arithmetic', False, False),
|
||||
('start_high_geometric', False, True))
|
||||
def test_adaptation_all_equal(self, start_low, geometric):
|
||||
# 20 equal records. Test that we converge to that record and bounce around
|
||||
# it. Unlike the linspace test, the quantile-matching objective is very
|
||||
# sharp at the optimum so a decaying learning rate is necessary.
|
||||
# `num_records` equal records. Test that we converge to that record and
|
||||
# bounce around it. Unlike the linspace test, the quantile-matching
|
||||
# objective is very sharp at the optimum so a decaying learning rate is
|
||||
# necessary.
|
||||
num_records = 20
|
||||
records = [tf.constant(5.0)] * num_records
|
||||
|
||||
|
|
|
@ -74,6 +74,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
updating is preferred for non-negative records like vector norms that
|
||||
could potentially be very large or very close to zero.
|
||||
"""
|
||||
|
||||
if target_quantile < 0 or target_quantile > 1:
|
||||
raise ValueError(
|
||||
f'`target_quantile` must be between 0 and 1, got {target_quantile}.')
|
||||
|
||||
if learning_rate < 0:
|
||||
raise ValueError(
|
||||
f'`learning_rate` must be non-negative, got {learning_rate}')
|
||||
|
||||
self._initial_estimate = initial_estimate
|
||||
self._target_quantile = target_quantile
|
||||
self._learning_rate = learning_rate
|
||||
|
@ -208,7 +217,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
|||
return no_privacy_query.NoPrivacyAverageQuery()
|
||||
|
||||
|
||||
class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||
class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||
|
||||
Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the
|
||||
|
|
|
@ -38,7 +38,7 @@ def _make_quantile_estimator_query(initial_estimate,
|
|||
tree_aggregation=False):
|
||||
if expected_num_records is not None:
|
||||
if tree_aggregation:
|
||||
return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery(
|
||||
return quantile_estimator_query.TreeQuantileEstimatorQuery(
|
||||
initial_estimate, target_quantile, learning_rate,
|
||||
below_estimate_stddev, expected_num_records, geometric_update)
|
||||
else:
|
||||
|
|
|
@ -360,6 +360,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return global_state.clip_value
|
||||
|
||||
def preprocess_record_l2_impl(self, params, record):
|
||||
"""Clips the l2 norm, returning the clipped record and the l2 norm.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
record: The record to be processed.
|
||||
|
||||
Returns:
|
||||
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
|
||||
the structure of preprocessed tensors, and l2_norm is the total l2 norm
|
||||
before clipping.
|
||||
"""
|
||||
l2_norm_clip = params
|
||||
record_as_list = tf.nest.flatten(record)
|
||||
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
||||
return tf.nest.pack_sequence_as(record, clipped_as_list), norm
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
|
||||
|
||||
|
@ -405,7 +422,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
`get_noised_result` when the restarting condition is met.
|
||||
|
||||
Args:
|
||||
noised_results: Noised cumulative sum returned by `get_noised_result`.
|
||||
noised_results: Noised results returned by `get_noised_result`.
|
||||
global_state: Updated global state returned by `get_noised_result`, which
|
||||
records noise for the conceptual cumulative sum of the current leaf
|
||||
node, and tree state for the next conceptual cumulative sum.
|
||||
|
@ -420,6 +437,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
previous_tree_noise=self._zero_initial_noise(),
|
||||
tree_state=new_tree_state)
|
||||
|
||||
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
|
||||
noise_generator_state = global_state.tree_state.value_generator_state
|
||||
assert isinstance(self._tree_aggregator.value_generator,
|
||||
tree_aggregation.GaussianNoiseGenerator)
|
||||
noise_generator_state = self._tree_aggregator.value_generator.make_state(
|
||||
noise_generator_state.seeds, stddev)
|
||||
new_tree_state = attr.evolve(
|
||||
global_state.tree_state, value_generator_state=noise_generator_state)
|
||||
return attr.evolve(
|
||||
global_state, clip_value=clip_norm, tree_state=new_tree_state)
|
||||
|
||||
@classmethod
|
||||
def build_l2_gaussian_query(cls,
|
||||
clip_norm,
|
||||
|
@ -442,8 +470,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
aggregation algorithm based on the paper "Efficient Use of
|
||||
Differentially Private Binary Trees".
|
||||
"""
|
||||
if clip_norm <= 0:
|
||||
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
|
||||
if clip_norm < 0:
|
||||
raise ValueError(f'`clip_norm` must be non-negative, got {clip_norm}.')
|
||||
|
||||
if noise_multiplier < 0:
|
||||
raise ValueError(
|
||||
|
|
Loading…
Reference in a new issue