forked from 626_privacy/tensorflow_privacy
Function to reset tree for tree aggregation based quantile estimation.
PiperOrigin-RevId: 399508765
This commit is contained in:
parent
b8b4c4b264
commit
99c82a49d8
2 changed files with 39 additions and 0 deletions
|
@ -242,3 +242,10 @@ class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
|
|||
record_specs=tf.TensorSpec([]))
|
||||
return normalized_query.NormalizedQuery(
|
||||
sum_query, denominator=expected_num_records)
|
||||
|
||||
def reset_state(self, noised_results, global_state):
|
||||
new_numerator_state = self._below_estimate_query._numerator.reset_state( # pylint: disable=protected-access,line-too-long
|
||||
noised_results, global_state.below_estimate_state.numerator_state)
|
||||
new_below_estimate_state = global_state.below_estimate_state._replace(
|
||||
numerator_state=new_numerator_state)
|
||||
return global_state._replace(below_estimate_state=new_below_estimate_state)
|
||||
|
|
|
@ -280,6 +280,38 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, 'scalar'):
|
||||
query.accumulate_record(None, None, [1.0, 2.0])
|
||||
|
||||
def test_tree_noise_restart(self):
|
||||
sample_num, tolerance, stddev = 1000, 0.3, 0.1
|
||||
initial_estimate, expected_num_records = 5., 2.
|
||||
record1 = tf.constant(1.)
|
||||
record2 = tf.constant(10.)
|
||||
|
||||
query = _make_quantile_estimator_query(
|
||||
initial_estimate=initial_estimate,
|
||||
target_quantile=.5,
|
||||
learning_rate=1.,
|
||||
below_estimate_stddev=stddev,
|
||||
expected_num_records=expected_num_records,
|
||||
geometric_update=False,
|
||||
tree_aggregation=True)
|
||||
|
||||
global_state = query.initial_global_state()
|
||||
|
||||
self.assertAllClose(global_state.current_estimate, initial_estimate)
|
||||
|
||||
# As the target quantile is accurate, there is no signal and only noise.
|
||||
samples = []
|
||||
for _ in range(sample_num):
|
||||
noised_estimate, global_state = test_utils.run_query(
|
||||
query, [record1, record2], global_state)
|
||||
samples.append(noised_estimate.numpy())
|
||||
global_state = query.reset_state(noised_estimate, global_state)
|
||||
self.assertNotEqual(global_state.current_estimate, initial_estimate)
|
||||
global_state = global_state._replace(current_estimate=initial_estimate)
|
||||
|
||||
self.assertAllClose(
|
||||
np.std(samples), stddev / expected_num_records, rtol=tolerance)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue