From e99fb7ea9baee079db99f440cb1e89a1990ba0a8 Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Wed, 1 Sep 2021 09:29:40 -0700 Subject: [PATCH] Try to fix flaky `tree_aggregation_query_test.test_noisy_cumsum_and_state_update`. PiperOrigin-RevId: 394248815 --- .../privacy/dp_query/tree_aggregation_query_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index 56118ce..ace6484 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -214,16 +214,16 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('two_records_noise_fn', [2.71828, 3.14159], _get_noise_fn), - ('five_records_noise_fn', np.random.uniform(size=5).tolist(), + ('five_records_noise_fn', np.random.uniform(low=0.1, size=5).tolist(), _get_noise_fn), ('two_records_generator', [2.71828, 3.14159], _get_noise_generator), - ('five_records_generator', np.random.uniform(size=5).tolist(), + ('five_records_generator', np.random.uniform(low=0.1, size=5).tolist(), _get_noise_generator), ) def test_noisy_cumsum_and_state_update(self, records, value_generator): num_trials = 200 - record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)), - records[0]) + record_specs = tf.TensorSpec([]) + records = [tf.constant(r) for r in records] noised_sums = [] for i in range(num_trials): query = tree_aggregation_query.TreeCumulativeSumQuery( @@ -232,7 +232,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): noise_generator=value_generator(record_specs, seed=i), record_specs=record_specs) query_result, _ = test_utils.run_query(query, records) - noised_sums.append(query_result) + noised_sums.append(query_result.numpy()) result_stddev = np.std(noised_sums) self.assertNear(result_stddev, NOISE_STD, 0.7) # value for chi-squared test