In TreeRangeSumQuery.preprocess_record
, move the reshaping operation before applying inner_query.preprocess_record
. The change is due to the newly checked-in DistributedDiscreteGaussianSumQuery
whose preprocess_record
requires explicit shape information during tracing.
PiperOrigin-RevId: 389392878
This commit is contained in:
parent
11900acf9b
commit
aa3f841893
1 changed files with 7 additions and 6 deletions
|
@ -507,9 +507,6 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
"""
|
||||
arity, inner_query_params = params
|
||||
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
|
||||
preprocessed_record = self._inner_query.preprocess_record(
|
||||
inner_query_params, preprocessed_record)
|
||||
|
||||
# The following codes reshape the output vector so the output shape of can
|
||||
# be statically inferred. This is useful when used with
|
||||
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
||||
|
@ -518,7 +515,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
||||
1) // (self._arity - 1)
|
||||
]
|
||||
return tf.reshape(preprocessed_record, preprocessed_record_shape)
|
||||
preprocessed_record = tf.reshape(preprocessed_record,
|
||||
preprocessed_record_shape)
|
||||
preprocessed_record = self._inner_query.preprocess_record(
|
||||
inner_query_params, preprocessed_record)
|
||||
|
||||
return preprocessed_record
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||
|
@ -543,8 +545,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
||||
sample_state, global_state.inner_query_state)
|
||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||
arity=global_state.arity,
|
||||
inner_query_state=inner_query_state)
|
||||
arity=global_state.arity, inner_query_state=inner_query_state)
|
||||
|
||||
row_splits = [0] + [
|
||||
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
||||
|
|
Loading…
Reference in a new issue