From 11900acf9ba4c70d876e480e787671d93f9952fe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 1 Aug 2021 23:13:00 -0700 Subject: [PATCH] Fixed the previous bug that `get_noised_result` does not map inner_query's `get_noised_result` to the input record and updates `global_state`. PiperOrigin-RevId: 388153296 --- .../privacy/dp_query/tree_aggregation_query.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 990391b..59ea0dc 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -540,13 +540,19 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): # This part is not written in tensorflow and will be executed on the server # side instead of the client side if used with # tff.aggregators.DifferentiallyPrivateFactory for federated learning. + sample_state, inner_query_state = self._inner_query.get_noised_result( + sample_state, global_state.inner_query_state) + new_global_state = TreeRangeSumQuery.GlobalState( + arity=global_state.arity, + inner_query_state=inner_query_state) + row_splits = [0] + [ (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range( math.floor(math.log(sample_state.shape[0], self._arity)) + 1) ] tree = tf.RaggedTensor.from_row_splits( values=sample_state, row_splits=row_splits) - return tree, global_state + return tree, new_global_state @classmethod def build_central_gaussian_query(cls,