From 5ad8676d38b9ca37b82ebbc39d941d6a2888f1bc Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Thu, 6 Aug 2020 14:19:22 -0700 Subject: [PATCH] Automated rollback of commit 5433436b863ec9d5822e4261e1b0637a4396a197 PiperOrigin-RevId: 325308999 --- tensorflow_privacy/privacy/dp_query/gaussian_query.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/gaussian_query.py b/tensorflow_privacy/privacy/dp_query/gaussian_query.py index 048595a..9095c44 100644 --- a/tensorflow_privacy/privacy/dp_query/gaussian_query.py +++ b/tensorflow_privacy/privacy/dp_query/gaussian_query.py @@ -21,7 +21,6 @@ from __future__ import print_function import collections import distutils -import numbers import tensorflow.compat.v1 as tf @@ -56,11 +55,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): def make_global_state(self, l2_norm_clip, stddev): """Creates a global state from the given parameters.""" - l2_norm_clip = float(l2_norm_clip) if isinstance( - l2_norm_clip, numbers.Number) else tf.cast(l2_norm_clip, tf.float32) - stddev = float(stddev) if isinstance(stddev, numbers.Number) else tf.cast( - stddev, tf.float32) - return self._GlobalState(l2_norm_clip, stddev) + return self._GlobalState(tf.cast(l2_norm_clip, tf.float32), + tf.cast(stddev, tf.float32)) def initial_global_state(self): return self.make_global_state(self._l2_norm_clip, self._stddev)