Function to reset tree for tree aggregation based quantile estimation.
PiperOrigin-RevId: 399508765
This commit is contained in:
parent
b8b4c4b264
commit
99c82a49d8
2 changed files with 39 additions and 0 deletions
|
@ -242,3 +242,10 @@ class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
record_specs=tf.TensorSpec([]))
|
record_specs=tf.TensorSpec([]))
|
||||||
return normalized_query.NormalizedQuery(
|
return normalized_query.NormalizedQuery(
|
||||||
sum_query, denominator=expected_num_records)
|
sum_query, denominator=expected_num_records)
|
||||||
|
|
||||||
|
def reset_state(self, noised_results, global_state):
|
||||||
|
new_numerator_state = self._below_estimate_query._numerator.reset_state( # pylint: disable=protected-access,line-too-long
|
||||||
|
noised_results, global_state.below_estimate_state.numerator_state)
|
||||||
|
new_below_estimate_state = global_state.below_estimate_state._replace(
|
||||||
|
numerator_state=new_numerator_state)
|
||||||
|
return global_state._replace(below_estimate_state=new_below_estimate_state)
|
||||||
|
|
|
@ -280,6 +280,38 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
with self.assertRaisesRegex(ValueError, 'scalar'):
|
with self.assertRaisesRegex(ValueError, 'scalar'):
|
||||||
query.accumulate_record(None, None, [1.0, 2.0])
|
query.accumulate_record(None, None, [1.0, 2.0])
|
||||||
|
|
||||||
|
def test_tree_noise_restart(self):
|
||||||
|
sample_num, tolerance, stddev = 1000, 0.3, 0.1
|
||||||
|
initial_estimate, expected_num_records = 5., 2.
|
||||||
|
record1 = tf.constant(1.)
|
||||||
|
record2 = tf.constant(10.)
|
||||||
|
|
||||||
|
query = _make_quantile_estimator_query(
|
||||||
|
initial_estimate=initial_estimate,
|
||||||
|
target_quantile=.5,
|
||||||
|
learning_rate=1.,
|
||||||
|
below_estimate_stddev=stddev,
|
||||||
|
expected_num_records=expected_num_records,
|
||||||
|
geometric_update=False,
|
||||||
|
tree_aggregation=True)
|
||||||
|
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
self.assertAllClose(global_state.current_estimate, initial_estimate)
|
||||||
|
|
||||||
|
# As the target quantile is accurate, there is no signal and only noise.
|
||||||
|
samples = []
|
||||||
|
for _ in range(sample_num):
|
||||||
|
noised_estimate, global_state = test_utils.run_query(
|
||||||
|
query, [record1, record2], global_state)
|
||||||
|
samples.append(noised_estimate.numpy())
|
||||||
|
global_state = query.reset_state(noised_estimate, global_state)
|
||||||
|
self.assertNotEqual(global_state.current_estimate, initial_estimate)
|
||||||
|
global_state = global_state._replace(current_estimate=initial_estimate)
|
||||||
|
|
||||||
|
self.assertAllClose(
|
||||||
|
np.std(samples), stddev / expected_num_records, rtol=tolerance)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue