From 463868e796d60668db75c67ee6838c31c3d67e77 Mon Sep 17 00:00:00 2001 From: Keith Rush Date: Wed, 22 Apr 2020 10:36:35 -0700 Subject: [PATCH] Relaxes dtype assumption in Gaussian DP sum query. PiperOrigin-RevId: 307846823 --- tensorflow_privacy/privacy/dp_query/gaussian_query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/gaussian_query.py b/tensorflow_privacy/privacy/dp_query/gaussian_query.py index 790015e..6f64965 100644 --- a/tensorflow_privacy/privacy/dp_query/gaussian_query.py +++ b/tensorflow_privacy/privacy/dp_query/gaussian_query.py @@ -94,13 +94,13 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): def add_noise(v): return v + tf.random.normal( - tf.shape(input=v), stddev=global_state.stddev) + tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype) else: random_normal = tf.random_normal_initializer( stddev=global_state.stddev) def add_noise(v): - return v + random_normal(tf.shape(input=v)) + return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) if self._ledger: dependencies = [