diff --git a/tensorflow_privacy/privacy/dp_query/normalized_query.py b/tensorflow_privacy/privacy/dp_query/normalized_query.py index e6e1b9c..07812e6 100644 --- a/tensorflow_privacy/privacy/dp_query/normalized_query.py +++ b/tensorflow_privacy/privacy/dp_query/normalized_query.py @@ -73,7 +73,7 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): sample_state, global_state.numerator_state) def normalize(v): - return tf.truediv(v, global_state.denominator) + return tf.truediv(v, tf.cast(global_state.denominator, v.dtype)) # The denominator is constant so the privacy cost comes from the numerator. return (tf.nest.map_structure(normalize, noised_sum),