Implement initial_sample_state for TreeRangeSumQuery.

PiperOrigin-RevId: 480685277
This commit is contained in:
A. Unique TensorFlower 2022-10-12 12:10:34 -07:00
parent 79fe32a60b
commit 5e37c1bc70
2 changed files with 17 additions and 0 deletions

View file

@ -136,6 +136,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
arity=self._arity,
inner_query_state=self._inner_query.initial_global_state())
def initial_sample_state(self, template=None):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return self.preprocess_record(
self.derive_sample_params(self.initial_global_state()),
super().initial_sample_state(template))
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.arity,

View file

@ -79,6 +79,17 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(global_state,
tree_range_query.TreeRangeSumQuery.GlobalState)
def test_pass_initial_sample_state_to_get_noised_result(self):
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., 2)
global_state = query.initial_global_state()
template = tf.TensorSpec.from_tensor(
tf.constant([1, 2, 3, 4], dtype=tf.int32))
sample_state = query.initial_sample_state(template)
result = query.get_noised_result(sample_state, global_state)[0]
expected_result = [[0], [0] * 2, [0] * 4]
self.assertAllClose(result, expected_result)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],