diff --git a/tensorflow_privacy/privacy/dp_query/gaussian_query.py b/tensorflow_privacy/privacy/dp_query/gaussian_query.py index 048595a..9095c44 100644 --- a/tensorflow_privacy/privacy/dp_query/gaussian_query.py +++ b/tensorflow_privacy/privacy/dp_query/gaussian_query.py @@ -21,7 +21,6 @@ from __future__ import print_function import collections import distutils -import numbers import tensorflow.compat.v1 as tf @@ -56,11 +55,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): def make_global_state(self, l2_norm_clip, stddev): """Creates a global state from the given parameters.""" - l2_norm_clip = float(l2_norm_clip) if isinstance( - l2_norm_clip, numbers.Number) else tf.cast(l2_norm_clip, tf.float32) - stddev = float(stddev) if isinstance(stddev, numbers.Number) else tf.cast( - stddev, tf.float32) - return self._GlobalState(l2_norm_clip, stddev) + return self._GlobalState(tf.cast(l2_norm_clip, tf.float32), + tf.cast(stddev, tf.float32)) def initial_global_state(self): return self.make_global_state(self._l2_norm_clip, self._stddev)