From b572707cfc578aa1347840e686378ac9b96bc905 Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Mon, 13 Sep 2021 17:47:50 -0700 Subject: [PATCH] Update reset and pre-process functions for tree aggregation queries. Minor comments update for adaptive clip query tests. PiperOrigin-RevId: 396483111 --- .../quantile_adaptive_clip_sum_query_test.py | 9 ++--- .../dp_query/quantile_estimator_query.py | 11 +++++- .../dp_query/quantile_estimator_query_test.py | 2 +- .../dp_query/tree_aggregation_query.py | 34 +++++++++++++++++-- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py index 51da202..5979266 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py @@ -230,7 +230,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase, ('start_high_arithmetic', False, False), ('start_high_geometric', False, True)) def test_adaptation_linspace(self, start_low, geometric): - # 100 records equally spaced from 0 to 10 in 0.1 increments. + # `num_records` records equally spaced from 0 to 10 in 0.1 increments. # Test that we converge to the correct median value and bounce around it. num_records = 21 records = [ @@ -262,9 +262,10 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase, ('start_high_arithmetic', False, False), ('start_high_geometric', False, True)) def test_adaptation_all_equal(self, start_low, geometric): - # 20 equal records. Test that we converge to that record and bounce around - # it. Unlike the linspace test, the quantile-matching objective is very - # sharp at the optimum so a decaying learning rate is necessary. + # `num_records` equal records. Test that we converge to that record and + # bounce around it. Unlike the linspace test, the quantile-matching + # objective is very sharp at the optimum so a decaying learning rate is + # necessary. num_records = 20 records = [tf.constant(5.0)] * num_records diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py index 9c90a03..4a453d6 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py @@ -74,6 +74,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): updating is preferred for non-negative records like vector norms that could potentially be very large or very close to zero. """ + + if target_quantile < 0 or target_quantile > 1: + raise ValueError( + f'`target_quantile` must be between 0 and 1, got {target_quantile}.') + + if learning_rate < 0: + raise ValueError( + f'`learning_rate` must be non-negative, got {learning_rate}') + self._initial_estimate = initial_estimate self._target_quantile = target_quantile self._learning_rate = learning_rate @@ -208,7 +217,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery): return no_privacy_query.NoPrivacyAverageQuery() -class TreeAggregationQuantileEstimatorQuery(QuantileEstimatorQuery): +class TreeQuantileEstimatorQuery(QuantileEstimatorQuery): """Iterative process to estimate target quantile of a univariate distribution. Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py index d349f56..e29fc4a 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py @@ -38,7 +38,7 @@ def _make_quantile_estimator_query(initial_estimate, tree_aggregation=False): if expected_num_records is not None: if tree_aggregation: - return quantile_estimator_query.TreeAggregationQuantileEstimatorQuery( + return quantile_estimator_query.TreeQuantileEstimatorQuery( initial_estimate, target_quantile, learning_rate, below_estimate_stddev, expected_num_records, geometric_update) else: diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 2752dba..70f9efa 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -360,6 +360,23 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return global_state.clip_value + def preprocess_record_l2_impl(self, params, record): + """Clips the l2 norm, returning the clipped record and the l2 norm. + + Args: + params: The parameters for the sample. + record: The record to be processed. + + Returns: + A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is + the structure of preprocessed tensors, and l2_norm is the total l2 norm + before clipping. + """ + l2_norm_clip = params + record_as_list = tf.nest.flatten(record) + clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip) + return tf.nest.pack_sequence_as(record, clipped_as_list), norm + def preprocess_record(self, params, record): """Implements `tensorflow_privacy.DPQuery.preprocess_record`. @@ -405,7 +422,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): `get_noised_result` when the restarting condition is met. Args: - noised_results: Noised cumulative sum returned by `get_noised_result`. + noised_results: Noised results returned by `get_noised_result`. global_state: Updated global state returned by `get_noised_result`, which records noise for the conceptual cumulative sum of the current leaf node, and tree state for the next conceptual cumulative sum. @@ -420,6 +437,17 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): previous_tree_noise=self._zero_initial_noise(), tree_state=new_tree_state) + def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev): + noise_generator_state = global_state.tree_state.value_generator_state + assert isinstance(self._tree_aggregator.value_generator, + tree_aggregation.GaussianNoiseGenerator) + noise_generator_state = self._tree_aggregator.value_generator.make_state( + noise_generator_state.seeds, stddev) + new_tree_state = attr.evolve( + global_state.tree_state, value_generator_state=noise_generator_state) + return attr.evolve( + global_state, clip_value=clip_norm, tree_state=new_tree_state) + @classmethod def build_l2_gaussian_query(cls, clip_norm, @@ -442,8 +470,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): aggregation algorithm based on the paper "Efficient Use of Differentially Private Binary Trees". """ - if clip_norm <= 0: - raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.') + if clip_norm < 0: + raise ValueError(f'`clip_norm` must be non-negative, got {clip_norm}.') if noise_multiplier < 0: raise ValueError(