forked from 626_privacy/tensorflow_privacy
Try to fix flaky tree_aggregation_query_test.test_noisy_cumsum_and_state_update
.
PiperOrigin-RevId: 394248815
This commit is contained in:
parent
7e7736ea91
commit
e99fb7ea9b
1 changed files with 5 additions and 5 deletions
|
@ -214,16 +214,16 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
@parameterized.named_parameters(
|
||||
('two_records_noise_fn', [2.71828, 3.14159], _get_noise_fn),
|
||||
('five_records_noise_fn', np.random.uniform(size=5).tolist(),
|
||||
('five_records_noise_fn', np.random.uniform(low=0.1, size=5).tolist(),
|
||||
_get_noise_fn),
|
||||
('two_records_generator', [2.71828, 3.14159], _get_noise_generator),
|
||||
('five_records_generator', np.random.uniform(size=5).tolist(),
|
||||
('five_records_generator', np.random.uniform(low=0.1, size=5).tolist(),
|
||||
_get_noise_generator),
|
||||
)
|
||||
def test_noisy_cumsum_and_state_update(self, records, value_generator):
|
||||
num_trials = 200
|
||||
record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)),
|
||||
records[0])
|
||||
record_specs = tf.TensorSpec([])
|
||||
records = [tf.constant(r) for r in records]
|
||||
noised_sums = []
|
||||
for i in range(num_trials):
|
||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||
|
@ -232,7 +232,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
noise_generator=value_generator(record_specs, seed=i),
|
||||
record_specs=record_specs)
|
||||
query_result, _ = test_utils.run_query(query, records)
|
||||
noised_sums.append(query_result)
|
||||
noised_sums.append(query_result.numpy())
|
||||
result_stddev = np.std(noised_sums)
|
||||
self.assertNear(result_stddev, NOISE_STD, 0.7) # value for chi-squared test
|
||||
|
||||
|
|
Loading…
Reference in a new issue