forked from 626_privacy/tensorflow_privacy
Ensures types remain consistent.
PiperOrigin-RevId: 563244784
This commit is contained in:
parent
c92610e37a
commit
a23cccde8b
1 changed files with 1 additions and 1 deletions
|
@ -73,7 +73,7 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||||
sample_state, global_state.numerator_state)
|
sample_state, global_state.numerator_state)
|
||||||
|
|
||||||
def normalize(v):
|
def normalize(v):
|
||||||
return tf.truediv(v, global_state.denominator)
|
return tf.truediv(v, tf.cast(global_state.denominator, v.dtype))
|
||||||
|
|
||||||
# The denominator is constant so the privacy cost comes from the numerator.
|
# The denominator is constant so the privacy cost comes from the numerator.
|
||||||
return (tf.nest.map_structure(normalize, noised_sum),
|
return (tf.nest.map_structure(normalize, noised_sum),
|
||||||
|
|
Loading…
Reference in a new issue