Try to fix flaky tree_aggregation_query_test.test_noisy_cumsum_and_state_update.

PiperOrigin-RevId: 394248815
This commit is contained in:
Zheng Xu 2021-09-01 09:29:40 -07:00 committed by A. Unique TensorFlower
parent 7e7736ea91
commit e99fb7ea9b

View file

@ -214,16 +214,16 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('two_records_noise_fn', [2.71828, 3.14159], _get_noise_fn),
('five_records_noise_fn', np.random.uniform(size=5).tolist(),
('five_records_noise_fn', np.random.uniform(low=0.1, size=5).tolist(),
_get_noise_fn),
('two_records_generator', [2.71828, 3.14159], _get_noise_generator),
('five_records_generator', np.random.uniform(size=5).tolist(),
('five_records_generator', np.random.uniform(low=0.1, size=5).tolist(),
_get_noise_generator),
)
def test_noisy_cumsum_and_state_update(self, records, value_generator):
num_trials = 200
record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)),
records[0])
record_specs = tf.TensorSpec([])
records = [tf.constant(r) for r in records]
noised_sums = []
for i in range(num_trials):
query = tree_aggregation_query.TreeCumulativeSumQuery(
@ -232,7 +232,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
noise_generator=value_generator(record_specs, seed=i),
record_specs=record_specs)
query_result, _ = test_utils.run_query(query, records)
noised_sums.append(query_result)
noised_sums.append(query_result.numpy())
result_stddev = np.std(noised_sums)
self.assertNear(result_stddev, NOISE_STD, 0.7) # value for chi-squared test