Change default restarter state in tree_aggregation_query to empty tuple as None type is not compatible with TFF.

PiperOrigin-RevId: 390278173
This commit is contained in:
Zheng Xu 2021-08-11 20:20:15 -07:00 committed by A. Unique TensorFlower
parent b4c04093cf
commit b8c1ba72cd

View file

@ -137,7 +137,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
initial_tree_state = self._tree_aggregator.init_state() initial_tree_state = self._tree_aggregator.init_state()
initial_samples_cumulative_sum = tf.nest.map_structure( initial_samples_cumulative_sum = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape), self._record_specs) lambda spec: tf.zeros(spec.shape), self._record_specs)
restarter_state = None restarter_state = ()
if self._restart_indicator is not None: if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize() restarter_state = self._restart_indicator.initialize()
return TreeCumulativeSumQuery.GlobalState( return TreeCumulativeSumQuery.GlobalState(
@ -368,7 +368,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
def initial_global_state(self): def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state() initial_tree_state = self._tree_aggregator.init_state()
restarter_state = None restarter_state = ()
if self._restart_indicator is not None: if self._restart_indicator is not None:
restarter_state = self._restart_indicator.initialize() restarter_state = self._restart_indicator.initialize()
return TreeResidualSumQuery.GlobalState( return TreeResidualSumQuery.GlobalState(