From 763694556616796d59db7c6385e439cf4a159af3 Mon Sep 17 00:00:00 2001 From: Galen Andrew Date: Fri, 24 May 2019 15:28:16 -0700 Subject: [PATCH] Cast to ensure record of NoPrivacyAverageQuery is float for compatibility with sample_state. PiperOrigin-RevId: 249909614 --- privacy/dp_query/no_privacy_query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/privacy/dp_query/no_privacy_query.py b/privacy/dp_query/no_privacy_query.py index 449b970..3d03ce7 100644 --- a/privacy/dp_query/no_privacy_query.py +++ b/privacy/dp_query/no_privacy_query.py @@ -55,17 +55,17 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery): def preprocess_record(self, params, record, weight=1): """Multiplies record by weight.""" weighted_record = nest.map_structure(lambda t: weight * t, record) - return (weighted_record, weight) + return (weighted_record, tf.cast(weight, tf.float32)) def accumulate_record(self, params, sample_state, record, weight=1): """Accumulates record, multiplying by weight.""" weighted_record = nest.map_structure(lambda t: weight * t, record) return self.accumulate_preprocessed_record( - sample_state, (weighted_record, weight)) + sample_state, (weighted_record, tf.cast(weight, tf.float32))) def get_noised_result(self, sample_state, global_state): """See base class.""" sum_state, denominator = sample_state return nest.map_structure( - lambda t: tf.truediv(t, denominator), sum_state), () + lambda t: tf.truediv(t, denominator), sum_state), global_state