In TreeRangeSumQuery.preprocess_record, move the reshaping operation before applying inner_query.preprocess_record. The change is due to the newly checked-in DistributedDiscreteGaussianSumQuery whose preprocess_record requires explicit shape information during tracing.

PiperOrigin-RevId: 389392878
This commit is contained in:
A. Unique TensorFlower 2021-08-07 11:21:13 -07:00
parent 11900acf9b
commit aa3f841893

View file

@ -507,9 +507,6 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
""" """
arity, inner_query_params = params arity, inner_query_params = params
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
preprocessed_record = self._inner_query.preprocess_record(
inner_query_params, preprocessed_record)
# The following codes reshape the output vector so the output shape of can # The following codes reshape the output vector so the output shape of can
# be statically inferred. This is useful when used with # be statically inferred. This is useful when used with
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know # `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
@ -518,7 +515,12 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) - (self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
1) // (self._arity - 1) 1) // (self._arity - 1)
] ]
return tf.reshape(preprocessed_record, preprocessed_record_shape) preprocessed_record = tf.reshape(preprocessed_record,
preprocessed_record_shape)
preprocessed_record = self._inner_query.preprocess_record(
inner_query_params, preprocessed_record)
return preprocessed_record
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`. """Implements `tensorflow_privacy.DPQuery.get_noised_result`.
@ -543,8 +545,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
sample_state, inner_query_state = self._inner_query.get_noised_result( sample_state, inner_query_state = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state) sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState( new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity, arity=global_state.arity, inner_query_state=inner_query_state)
inner_query_state=inner_query_state)
row_splits = [0] + [ row_splits = [0] + [
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range( (self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(