diff --git a/privacy/dp_query/no_privacy_query.py b/privacy/dp_query/no_privacy_query.py index 449b970..3d03ce7 100644 --- a/privacy/dp_query/no_privacy_query.py +++ b/privacy/dp_query/no_privacy_query.py @@ -55,17 +55,17 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery): def preprocess_record(self, params, record, weight=1): """Multiplies record by weight.""" weighted_record = nest.map_structure(lambda t: weight * t, record) - return (weighted_record, weight) + return (weighted_record, tf.cast(weight, tf.float32)) def accumulate_record(self, params, sample_state, record, weight=1): """Accumulates record, multiplying by weight.""" weighted_record = nest.map_structure(lambda t: weight * t, record) return self.accumulate_preprocessed_record( - sample_state, (weighted_record, weight)) + sample_state, (weighted_record, tf.cast(weight, tf.float32))) def get_noised_result(self, sample_state, global_state): """See base class.""" sum_state, denominator = sample_state return nest.map_structure( - lambda t: tf.truediv(t, denominator), sum_state), () + lambda t: tf.truediv(t, denominator), sum_state), global_state