diff --git a/tensorflow_privacy/privacy/dp_query/gaussian_query.py b/tensorflow_privacy/privacy/dp_query/gaussian_query.py index ca223ec..048595a 100644 --- a/tensorflow_privacy/privacy/dp_query/gaussian_query.py +++ b/tensorflow_privacy/privacy/dp_query/gaussian_query.py @@ -20,8 +20,9 @@ from __future__ import division from __future__ import print_function import collections +import distutils +import numbers -from distutils.version import LooseVersion import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.dp_query import dp_query @@ -55,8 +56,11 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): def make_global_state(self, l2_norm_clip, stddev): """Creates a global state from the given parameters.""" - return self._GlobalState(tf.cast(l2_norm_clip, tf.float32), - tf.cast(stddev, tf.float32)) + l2_norm_clip = float(l2_norm_clip) if isinstance( + l2_norm_clip, numbers.Number) else tf.cast(l2_norm_clip, tf.float32) + stddev = float(stddev) if isinstance(stddev, numbers.Number) else tf.cast( + stddev, tf.float32) + return self._GlobalState(l2_norm_clip, stddev) def initial_global_state(self): return self.make_global_state(self._l2_norm_clip, self._stddev) @@ -87,7 +91,9 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): def get_noised_result(self, sample_state, global_state): """See base class.""" - if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + if distutils.version.LooseVersion( + tf.__version__) < distutils.version.LooseVersion('2.0.0'): + def add_noise(v): return v + tf.random.normal( tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)