Fixed the previous bug that get_noised_result
does not map inner_query's get_noised_result
to the input record and updates global_state
.
PiperOrigin-RevId: 388153296
This commit is contained in:
parent
2672559471
commit
11900acf9b
1 changed files with 7 additions and 1 deletions
|
@ -540,13 +540,19 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
# This part is not written in tensorflow and will be executed on the server
|
# This part is not written in tensorflow and will be executed on the server
|
||||||
# side instead of the client side if used with
|
# side instead of the client side if used with
|
||||||
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
|
||||||
|
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
||||||
|
sample_state, global_state.inner_query_state)
|
||||||
|
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||||
|
arity=global_state.arity,
|
||||||
|
inner_query_state=inner_query_state)
|
||||||
|
|
||||||
row_splits = [0] + [
|
row_splits = [0] + [
|
||||||
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
||||||
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
|
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
|
||||||
]
|
]
|
||||||
tree = tf.RaggedTensor.from_row_splits(
|
tree = tf.RaggedTensor.from_row_splits(
|
||||||
values=sample_state, row_splits=row_splits)
|
values=sample_state, row_splits=row_splits)
|
||||||
return tree, global_state
|
return tree, new_global_state
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_central_gaussian_query(cls,
|
def build_central_gaussian_query(cls,
|
||||||
|
|
Loading…
Reference in a new issue