diff --git a/tensorflow_privacy/privacy/dp_query/tree_range_query.py b/tensorflow_privacy/privacy/dp_query/tree_range_query.py index 7cdc0b0..d86cc89 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_range_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_range_query.py @@ -136,6 +136,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): arity=self._arity, inner_query_state=self._inner_query.initial_global_state()) + def initial_sample_state(self, template=None): + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" + return self.preprocess_record( + self.derive_sample_params(self.initial_global_state()), + super().initial_sample_state(template)) + def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return (global_state.arity, diff --git a/tensorflow_privacy/privacy/dp_query/tree_range_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_range_query_test.py index 9cae1a1..cf45723 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_range_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_range_query_test.py @@ -79,6 +79,17 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase): self.assertIsInstance(global_state, tree_range_query.TreeRangeSumQuery.GlobalState) + def test_pass_initial_sample_state_to_get_noised_result(self): + query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query( + 10., 0., 2) + global_state = query.initial_global_state() + template = tf.TensorSpec.from_tensor( + tf.constant([1, 2, 3, 4], dtype=tf.int32)) + sample_state = query.initial_sample_state(template) + result = query.get_noised_result(sample_state, global_state)[0] + expected_result = [[0], [0] * 2, [0] * 4] + self.assertAllClose(result, expected_result) + @parameterized.product( inner_query=['central', 'distributed'], clip_norm=[0.1, 1.0, 10.0],