forked from 626_privacy/tensorflow_privacy
Try to fix flaky tree_aggregation_query_test.test_noisy_cumsum_and_state_update
.
PiperOrigin-RevId: 394248815
This commit is contained in:
parent
7e7736ea91
commit
e99fb7ea9b
1 changed files with 5 additions and 5 deletions
|
@ -214,16 +214,16 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('two_records_noise_fn', [2.71828, 3.14159], _get_noise_fn),
|
('two_records_noise_fn', [2.71828, 3.14159], _get_noise_fn),
|
||||||
('five_records_noise_fn', np.random.uniform(size=5).tolist(),
|
('five_records_noise_fn', np.random.uniform(low=0.1, size=5).tolist(),
|
||||||
_get_noise_fn),
|
_get_noise_fn),
|
||||||
('two_records_generator', [2.71828, 3.14159], _get_noise_generator),
|
('two_records_generator', [2.71828, 3.14159], _get_noise_generator),
|
||||||
('five_records_generator', np.random.uniform(size=5).tolist(),
|
('five_records_generator', np.random.uniform(low=0.1, size=5).tolist(),
|
||||||
_get_noise_generator),
|
_get_noise_generator),
|
||||||
)
|
)
|
||||||
def test_noisy_cumsum_and_state_update(self, records, value_generator):
|
def test_noisy_cumsum_and_state_update(self, records, value_generator):
|
||||||
num_trials = 200
|
num_trials = 200
|
||||||
record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)),
|
record_specs = tf.TensorSpec([])
|
||||||
records[0])
|
records = [tf.constant(r) for r in records]
|
||||||
noised_sums = []
|
noised_sums = []
|
||||||
for i in range(num_trials):
|
for i in range(num_trials):
|
||||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||||
|
@ -232,7 +232,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
noise_generator=value_generator(record_specs, seed=i),
|
noise_generator=value_generator(record_specs, seed=i),
|
||||||
record_specs=record_specs)
|
record_specs=record_specs)
|
||||||
query_result, _ = test_utils.run_query(query, records)
|
query_result, _ = test_utils.run_query(query, records)
|
||||||
noised_sums.append(query_result)
|
noised_sums.append(query_result.numpy())
|
||||||
result_stddev = np.std(noised_sums)
|
result_stddev = np.std(noised_sums)
|
||||||
self.assertNear(result_stddev, NOISE_STD, 0.7) # value for chi-squared test
|
self.assertNear(result_stddev, NOISE_STD, 0.7) # value for chi-squared test
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue