Function to reset tree for tree aggregation based quantile estimation.

PiperOrigin-RevId: 399508765
This commit is contained in:
Zheng Xu 2021-09-28 12:55:22 -07:00 committed by A. Unique TensorFlower
parent b8b4c4b264
commit 99c82a49d8
2 changed files with 39 additions and 0 deletions

View file

@ -242,3 +242,10 @@ class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
record_specs=tf.TensorSpec([]))
return normalized_query.NormalizedQuery(
sum_query, denominator=expected_num_records)
def reset_state(self, noised_results, global_state):
new_numerator_state = self._below_estimate_query._numerator.reset_state( # pylint: disable=protected-access,line-too-long
noised_results, global_state.below_estimate_state.numerator_state)
new_below_estimate_state = global_state.below_estimate_state._replace(
numerator_state=new_numerator_state)
return global_state._replace(below_estimate_state=new_below_estimate_state)

View file

@ -280,6 +280,38 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'scalar'):
query.accumulate_record(None, None, [1.0, 2.0])
def test_tree_noise_restart(self):
sample_num, tolerance, stddev = 1000, 0.3, 0.1
initial_estimate, expected_num_records = 5., 2.
record1 = tf.constant(1.)
record2 = tf.constant(10.)
query = _make_quantile_estimator_query(
initial_estimate=initial_estimate,
target_quantile=.5,
learning_rate=1.,
below_estimate_stddev=stddev,
expected_num_records=expected_num_records,
geometric_update=False,
tree_aggregation=True)
global_state = query.initial_global_state()
self.assertAllClose(global_state.current_estimate, initial_estimate)
# As the target quantile is accurate, there is no signal and only noise.
samples = []
for _ in range(sample_num):
noised_estimate, global_state = test_utils.run_query(
query, [record1, record2], global_state)
samples.append(noised_estimate.numpy())
global_state = query.reset_state(noised_estimate, global_state)
self.assertNotEqual(global_state.current_estimate, initial_estimate)
global_state = global_state._replace(current_estimate=initial_estimate)
self.assertAllClose(
np.std(samples), stddev / expected_num_records, rtol=tolerance)
if __name__ == '__main__':
tf.test.main()