Relaxes dtype assumption in Gaussian DP sum query.
PiperOrigin-RevId: 307846823
This commit is contained in:
parent
c5c807807f
commit
463868e796
1 changed files with 2 additions and 2 deletions
|
@ -94,13 +94,13 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.random.normal(
|
return v + tf.random.normal(
|
||||||
tf.shape(input=v), stddev=global_state.stddev)
|
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
|
||||||
else:
|
else:
|
||||||
random_normal = tf.random_normal_initializer(
|
random_normal = tf.random_normal_initializer(
|
||||||
stddev=global_state.stddev)
|
stddev=global_state.stddev)
|
||||||
|
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + random_normal(tf.shape(input=v))
|
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||||
|
|
||||||
if self._ledger:
|
if self._ledger:
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|
Loading…
Reference in a new issue