Change GaussianSumQuery to not convert hyperparameters from Python numbers to Tensors.
PiperOrigin-RevId: 325251302
This commit is contained in:
parent
efca03b593
commit
5433436b86
1 changed files with 10 additions and 4 deletions
|
@ -20,8 +20,9 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import distutils
|
||||||
|
import numbers
|
||||||
|
|
||||||
from distutils.version import LooseVersion
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
@ -55,8 +56,11 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def make_global_state(self, l2_norm_clip, stddev):
|
def make_global_state(self, l2_norm_clip, stddev):
|
||||||
"""Creates a global state from the given parameters."""
|
"""Creates a global state from the given parameters."""
|
||||||
return self._GlobalState(tf.cast(l2_norm_clip, tf.float32),
|
l2_norm_clip = float(l2_norm_clip) if isinstance(
|
||||||
tf.cast(stddev, tf.float32))
|
l2_norm_clip, numbers.Number) else tf.cast(l2_norm_clip, tf.float32)
|
||||||
|
stddev = float(stddev) if isinstance(stddev, numbers.Number) else tf.cast(
|
||||||
|
stddev, tf.float32)
|
||||||
|
return self._GlobalState(l2_norm_clip, stddev)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
return self.make_global_state(self._l2_norm_clip, self._stddev)
|
return self.make_global_state(self._l2_norm_clip, self._stddev)
|
||||||
|
@ -87,7 +91,9 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
if distutils.version.LooseVersion(
|
||||||
|
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
||||||
|
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.random.normal(
|
return v + tf.random.normal(
|
||||||
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
|
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
|
||||||
|
|
Loading…
Reference in a new issue