In TreeRangeSumQuery.preprocess_record
, move the reshaping operation before applying inner_query.preprocess_record
. The change is due to the newly checked-in DistributedDiscreteGaussianSumQuery
whose preprocess_record
requires explicit shape information during tracing.
PiperOrigin-RevId: 389392878
This commit is contained in:
parent
11900acf9b
commit
aa3f841893
1 changed files with 7 additions and 6 deletions
|
@ -507,9 +507,6 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""
|
"""
|
||||||
arity, inner_query_params = params
|
arity, inner_query_params = params
|
||||||
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
|
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
|
||||||
preprocessed_record = self._inner_query.preprocess_record(
|
|
||||||
inner_query_params, preprocessed_record)
|
|
||||||
|
|
||||||
# The following codes reshape the output vector so the output shape of can
|
# The following codes reshape the output vector so the output shape of can
|
||||||
# be statically inferred. This is useful when used with
|
# be statically inferred. This is useful when used with
|
||||||
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
|
||||||
|
@ -518,7 +515,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
|
||||||
1) // (self._arity - 1)
|
1) // (self._arity - 1)
|
||||||
]
|
]
|
||||||
return tf.reshape(preprocessed_record, preprocessed_record_shape)
|
preprocessed_record = tf.reshape(preprocessed_record,
|
||||||
|
preprocessed_record_shape)
|
||||||
|
preprocessed_record = self._inner_query.preprocess_record(
|
||||||
|
inner_query_params, preprocessed_record)
|
||||||
|
|
||||||
|
return preprocessed_record
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
|
||||||
|
@ -543,8 +545,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
sample_state, inner_query_state = self._inner_query.get_noised_result(
|
||||||
sample_state, global_state.inner_query_state)
|
sample_state, global_state.inner_query_state)
|
||||||
new_global_state = TreeRangeSumQuery.GlobalState(
|
new_global_state = TreeRangeSumQuery.GlobalState(
|
||||||
arity=global_state.arity,
|
arity=global_state.arity, inner_query_state=inner_query_state)
|
||||||
inner_query_state=inner_query_state)
|
|
||||||
|
|
||||||
row_splits = [0] + [
|
row_splits = [0] + [
|
||||||
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
|
||||||
|
|
Loading…
Reference in a new issue