From a23cccde8b25883f333c48ea8b47806bb1b64be8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Sep 2023 16:12:46 -0700 Subject: [PATCH] Ensures types remain consistent. PiperOrigin-RevId: 563244784 --- tensorflow_privacy/privacy/dp_query/normalized_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_privacy/privacy/dp_query/normalized_query.py b/tensorflow_privacy/privacy/dp_query/normalized_query.py index e6e1b9c..07812e6 100644 --- a/tensorflow_privacy/privacy/dp_query/normalized_query.py +++ b/tensorflow_privacy/privacy/dp_query/normalized_query.py @@ -73,7 +73,7 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): sample_state, global_state.numerator_state) def normalize(v): - return tf.truediv(v, global_state.denominator) + return tf.truediv(v, tf.cast(global_state.denominator, v.dtype)) # The denominator is constant so the privacy cost comes from the numerator. return (tf.nest.map_structure(normalize, noised_sum),