Fixed the previous bug that get_noised_result
does not map inner_query's get_noised_result
to the input record and updates global_state
.
PiperOrigin-RevId: 388153296
This commit is contained in:
parent
2672559471
commit
11900acf9b
1 changed files with 7 additions and 1 deletions
|
@ -540,13 +540,19 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
# This part is not written in tensorflow and will be executed on the server
|
||||
# side instead of the client side if used with
|
||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
||||
sample_state, global_state.inner_query_state)
|
||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||
arity=global_state.arity,
|
||||
inner_query_state=inner_query_state)
|
||||
|
||||
row_splits = [0] + [
|
||||
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
||||
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
|
||||
]
|
||||
tree = tf.RaggedTensor.from_row_splits(
|
||||
values=sample_state, row_splits=row_splits)
|
||||
return tree, global_state
|
||||
return tree, new_global_state
|
||||
|
||||
@classmethod
|
||||
def build_central_gaussian_query(cls,
|
||||
|
|
Loading…
Reference in a new issue