From 99c82a49d8e9c5f24efd2e262a4b071683abb2b5 Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Tue, 28 Sep 2021 12:55:22 -0700 Subject: [PATCH] Function to reset tree for tree aggregation based quantile estimation. PiperOrigin-RevId: 399508765 --- .../dp_query/quantile_estimator_query.py | 7 ++++ .../dp_query/quantile_estimator_query_test.py | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py index 4a453d6..0708016 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py @@ -242,3 +242,10 @@ class TreeQuantileEstimatorQuery(QuantileEstimatorQuery): record_specs=tf.TensorSpec([])) return normalized_query.NormalizedQuery( sum_query, denominator=expected_num_records) + + def reset_state(self, noised_results, global_state): + new_numerator_state = self._below_estimate_query._numerator.reset_state( # pylint: disable=protected-access,line-too-long + noised_results, global_state.below_estimate_state.numerator_state) + new_below_estimate_state = global_state.below_estimate_state._replace( + numerator_state=new_numerator_state) + return global_state._replace(below_estimate_state=new_below_estimate_state) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py index e29fc4a..fa3f03e 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py @@ -280,6 +280,38 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'scalar'): query.accumulate_record(None, None, [1.0, 2.0]) + def test_tree_noise_restart(self): + sample_num, tolerance, stddev = 1000, 0.3, 0.1 + initial_estimate, expected_num_records = 5., 2. + record1 = tf.constant(1.) + record2 = tf.constant(10.) + + query = _make_quantile_estimator_query( + initial_estimate=initial_estimate, + target_quantile=.5, + learning_rate=1., + below_estimate_stddev=stddev, + expected_num_records=expected_num_records, + geometric_update=False, + tree_aggregation=True) + + global_state = query.initial_global_state() + + self.assertAllClose(global_state.current_estimate, initial_estimate) + + # As the target quantile is accurate, there is no signal and only noise. + samples = [] + for _ in range(sample_num): + noised_estimate, global_state = test_utils.run_query( + query, [record1, record2], global_state) + samples.append(noised_estimate.numpy()) + global_state = query.reset_state(noised_estimate, global_state) + self.assertNotEqual(global_state.current_estimate, initial_estimate) + global_state = global_state._replace(current_estimate=initial_estimate) + + self.assertAllClose( + np.std(samples), stddev / expected_num_records, rtol=tolerance) + if __name__ == '__main__': tf.test.main()