Relaxes dtype assumption in Gaussian DP sum query.

PiperOrigin-RevId: 307846823
This commit is contained in:
Keith Rush 2020-04-22 10:36:35 -07:00 committed by A. Unique TensorFlower
parent c5c807807f
commit 463868e796

View file

@ -94,13 +94,13 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
def add_noise(v): def add_noise(v):
return v + tf.random.normal( return v + tf.random.normal(
tf.shape(input=v), stddev=global_state.stddev) tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
else: else:
random_normal = tf.random_normal_initializer( random_normal = tf.random_normal_initializer(
stddev=global_state.stddev) stddev=global_state.stddev)
def add_noise(v): def add_noise(v):
return v + random_normal(tf.shape(input=v)) return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
if self._ledger: if self._ledger:
dependencies = [ dependencies = [