Implement initial_sample_state for TreeRangeSumQuery.
PiperOrigin-RevId: 480685277
This commit is contained in:
parent
79fe32a60b
commit
5e37c1bc70
2 changed files with 17 additions and 0 deletions
|
@ -136,6 +136,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
arity=self._arity,
|
arity=self._arity,
|
||||||
inner_query_state=self._inner_query.initial_global_state())
|
inner_query_state=self._inner_query.initial_global_state())
|
||||||
|
|
||||||
|
def initial_sample_state(self, template=None):
|
||||||
|
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
|
||||||
|
return self.preprocess_record(
|
||||||
|
self.derive_sample_params(self.initial_global_state()),
|
||||||
|
super().initial_sample_state(template))
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
return (global_state.arity,
|
return (global_state.arity,
|
||||||
|
|
|
@ -79,6 +79,17 @@ class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertIsInstance(global_state,
|
self.assertIsInstance(global_state,
|
||||||
tree_range_query.TreeRangeSumQuery.GlobalState)
|
tree_range_query.TreeRangeSumQuery.GlobalState)
|
||||||
|
|
||||||
|
def test_pass_initial_sample_state_to_get_noised_result(self):
|
||||||
|
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
|
||||||
|
10., 0., 2)
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
template = tf.TensorSpec.from_tensor(
|
||||||
|
tf.constant([1, 2, 3, 4], dtype=tf.int32))
|
||||||
|
sample_state = query.initial_sample_state(template)
|
||||||
|
result = query.get_noised_result(sample_state, global_state)[0]
|
||||||
|
expected_result = [[0], [0] * 2, [0] * 4]
|
||||||
|
self.assertAllClose(result, expected_result)
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
inner_query=['central', 'distributed'],
|
inner_query=['central', 'distributed'],
|
||||||
clip_norm=[0.1, 1.0, 10.0],
|
clip_norm=[0.1, 1.0, 10.0],
|
||||||
|
|
Loading…
Reference in a new issue