From b8c1ba72cdb195be51faea4b9958661cc6737151 Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Wed, 11 Aug 2021 20:20:15 -0700 Subject: [PATCH] Change default restarter state in tree_aggregation_query to empty tuple as None type is not compatible with TFF. PiperOrigin-RevId: 390278173 --- tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 943cf9f..7ef73a1 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -137,7 +137,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): initial_tree_state = self._tree_aggregator.init_state() initial_samples_cumulative_sum = tf.nest.map_structure( lambda spec: tf.zeros(spec.shape), self._record_specs) - restarter_state = None + restarter_state = () if self._restart_indicator is not None: restarter_state = self._restart_indicator.initialize() return TreeCumulativeSumQuery.GlobalState( @@ -368,7 +368,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() - restarter_state = None + restarter_state = () if self._restart_indicator is not None: restarter_state = self._restart_indicator.initialize() return TreeResidualSumQuery.GlobalState(