diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 943cf9f..7ef73a1 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -137,7 +137,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): initial_tree_state = self._tree_aggregator.init_state() initial_samples_cumulative_sum = tf.nest.map_structure( lambda spec: tf.zeros(spec.shape), self._record_specs) - restarter_state = None + restarter_state = () if self._restart_indicator is not None: restarter_state = self._restart_indicator.initialize() return TreeCumulativeSumQuery.GlobalState( @@ -368,7 +368,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() - restarter_state = None + restarter_state = () if self._restart_indicator is not None: restarter_state = self._restart_indicator.initialize() return TreeResidualSumQuery.GlobalState(