diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index ace6484..1115f40 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -221,9 +221,9 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase): _get_noise_generator), ) def test_noisy_cumsum_and_state_update(self, records, value_generator): - num_trials = 200 - record_specs = tf.TensorSpec([]) - records = [tf.constant(r) for r in records] + num_trials, vector_size = 10, 100 + record_specs = tf.TensorSpec([vector_size]) + records = [tf.constant(r, shape=[vector_size]) for r in records] noised_sums = [] for i in range(num_trials): query = tree_aggregation_query.TreeCumulativeSumQuery(