Cast to ensure record of NoPrivacyAverageQuery is float for compatibility with sample_state.

PiperOrigin-RevId: 249909614
This commit is contained in:
Galen Andrew 2019-05-24 15:28:16 -07:00 committed by A. Unique TensorFlower
parent 15c07250a1
commit 7636945566

View file

@ -55,17 +55,17 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
def preprocess_record(self, params, record, weight=1):
"""Multiplies record by weight."""
weighted_record = nest.map_structure(lambda t: weight * t, record)
return (weighted_record, weight)
return (weighted_record, tf.cast(weight, tf.float32))
def accumulate_record(self, params, sample_state, record, weight=1):
"""Accumulates record, multiplying by weight."""
weighted_record = nest.map_structure(lambda t: weight * t, record)
return self.accumulate_preprocessed_record(
sample_state, (weighted_record, weight))
sample_state, (weighted_record, tf.cast(weight, tf.float32)))
def get_noised_result(self, sample_state, global_state):
"""See base class."""
sum_state, denominator = sample_state
return nest.map_structure(
lambda t: tf.truediv(t, denominator), sum_state), ()
lambda t: tf.truediv(t, denominator), sum_state), global_state