From 13b3a04a3e0d3072ce99ebd40d0a4b3939b7a77c Mon Sep 17 00:00:00 2001 From: pranav subramani Date: Fri, 8 Jan 2021 00:23:32 -0700 Subject: [PATCH] update keras model --- tensorflow_privacy/privacy/keras_models/dp_keras_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 6565082..b854f54 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -27,7 +27,7 @@ def make_dp_model_class(cls): noise = tf.random.normal( tf.shape(input=summed_grads), stddev=noise_stddev) noised_grads = summed_grads + noise - return noised_grads / tf.cast(stacked_grads.shape[0], tf.float32) + return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype) def compute_per_example_grads(self, data): x, y = data