Implement initial_sample_state for TreeRangeSumQuery.
PiperOrigin-RevId: 480685277
This commit is contained in:
parent
79fe32a60b
commit
5e37c1bc70
2 changed files with 17 additions and 0 deletions
|
@ -136,6 +136,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
arity=self._arity,
|
||||
inner_query_state=self._inner_query.initial_global_state())
|
||||
|
||||
def initial_sample_state(self, template=None):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||
return self.preprocess_record(
|
||||
self.derive_sample_params(self.initial_global_state()),
|
||||
super().initial_sample_state(template))
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
return (global_state.arity,
|
||||
|
|
|
@ -79,6 +79,17 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertIsInstance(global_state,
|
||||
tree_range_query.TreeRangeSumQuery.GlobalState)
|
||||
|
||||
def test_pass_initial_sample_state_to_get_noised_result(self):
|
||||
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||
10., 0., 2)
|
||||
global_state = query.initial_global_state()
|
||||
template = tf.TensorSpec.from_tensor(
|
||||
tf.constant([1, 2, 3, 4], dtype=tf.int32))
|
||||
sample_state = query.initial_sample_state(template)
|
||||
result = query.get_noised_result(sample_state, global_state)[0]
|
||||
expected_result = [[0], [0] * 2, [0] * 4]
|
||||
self.assertAllClose(result, expected_result)
|
||||
|
||||
@parameterized.product(
|
||||
inner_query=['central', 'distributed'],
|
||||
clip_norm=[0.1, 1.0, 10.0],
|
||||
|
|
Loading…
Reference in a new issue