Change default restarter state in tree_aggregation_query to empty tuple as None type is not compatible with TFF.
PiperOrigin-RevId: 390278173
This commit is contained in:
parent
b4c04093cf
commit
b8c1ba72cd
1 changed files with 2 additions and 2 deletions
|
@ -137,7 +137,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
initial_tree_state = self._tree_aggregator.init_state()
|
initial_tree_state = self._tree_aggregator.init_state()
|
||||||
initial_samples_cumulative_sum = tf.nest.map_structure(
|
initial_samples_cumulative_sum = tf.nest.map_structure(
|
||||||
lambda spec: tf.zeros(spec.shape), self._record_specs)
|
lambda spec: tf.zeros(spec.shape), self._record_specs)
|
||||||
restarter_state = None
|
restarter_state = ()
|
||||||
if self._restart_indicator is not None:
|
if self._restart_indicator is not None:
|
||||||
restarter_state = self._restart_indicator.initialize()
|
restarter_state = self._restart_indicator.initialize()
|
||||||
return TreeCumulativeSumQuery.GlobalState(
|
return TreeCumulativeSumQuery.GlobalState(
|
||||||
|
@ -368,7 +368,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
initial_tree_state = self._tree_aggregator.init_state()
|
initial_tree_state = self._tree_aggregator.init_state()
|
||||||
restarter_state = None
|
restarter_state = ()
|
||||||
if self._restart_indicator is not None:
|
if self._restart_indicator is not None:
|
||||||
restarter_state = self._restart_indicator.initialize()
|
restarter_state = self._restart_indicator.initialize()
|
||||||
return TreeResidualSumQuery.GlobalState(
|
return TreeResidualSumQuery.GlobalState(
|
||||||
|
|
Loading…
Reference in a new issue