Fixed the previous bug that get_noised_result does not map inner_query's get_noised_result to the input record and updates global_state.

PiperOrigin-RevId: 388153296
This commit is contained in:
A. Unique TensorFlower 2021-08-01 23:13:00 -07:00
parent 2672559471
commit 11900acf9b

View file

@ -540,13 +540,19 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
sample_state, inner_query_state = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity,
inner_query_state=inner_query_state)
row_splits = [0] + [
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
return tree, global_state
return tree, new_global_state
@classmethod
def build_central_gaussian_query(cls,