diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 6565082..b854f54 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -27,7 +27,7 @@ def make_dp_model_class(cls): noise = tf.random.normal( tf.shape(input=summed_grads), stddev=noise_stddev) noised_grads = summed_grads + noise - return noised_grads / tf.cast(stacked_grads.shape[0], tf.float32) + return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype) def compute_per_example_grads(self, data): x, y = data