diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 990391b..59ea0dc 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -540,13 +540,19 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): # This part is not written in tensorflow and will be executed on the server # side instead of the client side if used with # tff.aggregators.DifferentiallyPrivateFactory for federated learning. + sample_state, inner_query_state = self._inner_query.get_noised_result( + sample_state, global_state.inner_query_state) + new_global_state = TreeRangeSumQuery.GlobalState( + arity=global_state.arity, + inner_query_state=inner_query_state) + row_splits = [0] + [ (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range( math.floor(math.log(sample_state.shape[0], self._arity)) + 1) ] tree = tf.RaggedTensor.from_row_splits( values=sample_state, row_splits=row_splits) - return tree, global_state + return tree, new_global_state @classmethod def build_central_gaussian_query(cls,