From aa3f841893b15e856f93fe33ab2c33ef7a3f4442 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 7 Aug 2021 11:21:13 -0700 Subject: [PATCH] In `TreeRangeSumQuery.preprocess_record`, move the reshaping operation before applying `inner_query.preprocess_record`. The change is due to the newly checked-in `DistributedDiscreteGaussianSumQuery` whose `preprocess_record` requires explicit shape information during tracing. PiperOrigin-RevId: 389392878 --- .../privacy/dp_query/tree_aggregation_query.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 59ea0dc..082bf01 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -507,9 +507,6 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): """ arity, inner_query_params = params preprocessed_record = _build_tree_from_leaf(record, arity).flat_values - preprocessed_record = self._inner_query.preprocess_record( - inner_query_params, preprocessed_record) - # The following codes reshape the output vector so the output shape of can # be statically inferred. This is useful when used with # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know @@ -518,7 +515,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - 1) // (self._arity - 1) ] - return tf.reshape(preprocessed_record, preprocessed_record_shape) + preprocessed_record = tf.reshape(preprocessed_record, + preprocessed_record_shape) + preprocessed_record = self._inner_query.preprocess_record( + inner_query_params, preprocessed_record) + + return preprocessed_record def get_noised_result(self, sample_state, global_state): """Implements `tensorflow_privacy.DPQuery.get_noised_result`. @@ -543,8 +545,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): sample_state, inner_query_state = self._inner_query.get_noised_result( sample_state, global_state.inner_query_state) new_global_state = TreeRangeSumQuery.GlobalState( - arity=global_state.arity, - inner_query_state=inner_query_state) + arity=global_state.arity, inner_query_state=inner_query_state) row_splits = [0] + [ (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(