Ensures types remain consistent.

PiperOrigin-RevId: 563244784
This commit is contained in:
A. Unique TensorFlower 2023-09-06 16:12:46 -07:00
parent c92610e37a
commit a23cccde8b

View file

@ -73,7 +73,7 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
sample_state, global_state.numerator_state)
def normalize(v):
return tf.truediv(v, global_state.denominator)
return tf.truediv(v, tf.cast(global_state.denominator, v.dtype))
# The denominator is constant so the privacy cost comes from the numerator.
return (tf.nest.map_structure(normalize, noised_sum),