forked from 626_privacy/tensorflow_privacy
Ensures types remain consistent.
PiperOrigin-RevId: 563244784
This commit is contained in:
parent
c92610e37a
commit
a23cccde8b
1 changed files with 1 additions and 1 deletions
|
@ -73,7 +73,7 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
sample_state, global_state.numerator_state)
|
||||
|
||||
def normalize(v):
|
||||
return tf.truediv(v, global_state.denominator)
|
||||
return tf.truediv(v, tf.cast(global_state.denominator, v.dtype))
|
||||
|
||||
# The denominator is constant so the privacy cost comes from the numerator.
|
||||
return (tf.nest.map_structure(normalize, noised_sum),
|
||||
|
|
Loading…
Reference in a new issue