Change GaussianSumQuery to not convert hyperparameters from Python numbers to Tensors.

PiperOrigin-RevId: 325251302
This commit is contained in:
Steve Chien 2020-08-06 09:56:05 -07:00 committed by A. Unique TensorFlower
parent efca03b593
commit 5433436b86

View file

@ -20,8 +20,9 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import distutils
import numbers
from distutils.version import LooseVersion
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
@ -55,8 +56,11 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def make_global_state(self, l2_norm_clip, stddev): def make_global_state(self, l2_norm_clip, stddev):
"""Creates a global state from the given parameters.""" """Creates a global state from the given parameters."""
return self._GlobalState(tf.cast(l2_norm_clip, tf.float32), l2_norm_clip = float(l2_norm_clip) if isinstance(
tf.cast(stddev, tf.float32)) l2_norm_clip, numbers.Number) else tf.cast(l2_norm_clip, tf.float32)
stddev = float(stddev) if isinstance(stddev, numbers.Number) else tf.cast(
stddev, tf.float32)
return self._GlobalState(l2_norm_clip, stddev)
def initial_global_state(self): def initial_global_state(self):
return self.make_global_state(self._l2_norm_clip, self._stddev) return self.make_global_state(self._l2_norm_clip, self._stddev)
@ -87,7 +91,9 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
"""See base class.""" """See base class."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if distutils.version.LooseVersion(
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
def add_noise(v): def add_noise(v):
return v + tf.random.normal( return v + tf.random.normal(
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype) tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)