Relaxes dtype assumption in Gaussian DP sum query.
PiperOrigin-RevId: 307846823
This commit is contained in:
parent
c5c807807f
commit
463868e796
1 changed files with 2 additions and 2 deletions
|
@ -94,13 +94,13 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
def add_noise(v):
|
||||
return v + tf.random.normal(
|
||||
tf.shape(input=v), stddev=global_state.stddev)
|
||||
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
|
||||
else:
|
||||
random_normal = tf.random_normal_initializer(
|
||||
stddev=global_state.stddev)
|
||||
|
||||
def add_noise(v):
|
||||
return v + random_normal(tf.shape(input=v))
|
||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||
|
||||
if self._ledger:
|
||||
dependencies = [
|
||||
|
|
Loading…
Reference in a new issue