diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 59ea0dc..082bf01 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -507,9 +507,6 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): """ arity, inner_query_params = params preprocessed_record = _build_tree_from_leaf(record, arity).flat_values - preprocessed_record = self._inner_query.preprocess_record( - inner_query_params, preprocessed_record) - # The following codes reshape the output vector so the output shape of can # be statically inferred. This is useful when used with # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know @@ -518,7 +515,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - 1) // (self._arity - 1) ] - return tf.reshape(preprocessed_record, preprocessed_record_shape) + preprocessed_record = tf.reshape(preprocessed_record, + preprocessed_record_shape) + preprocessed_record = self._inner_query.preprocess_record( + inner_query_params, preprocessed_record) + + return preprocessed_record def get_noised_result(self, sample_state, global_state): """Implements `tensorflow_privacy.DPQuery.get_noised_result`. @@ -543,8 +545,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): sample_state, inner_query_state = self._inner_query.get_noised_result( sample_state, global_state.inner_query_state) new_global_state = TreeRangeSumQuery.GlobalState( - arity=global_state.arity, - inner_query_state=inner_query_state) + arity=global_state.arity, inner_query_state=inner_query_state) row_splits = [0] + [ (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(