forked from 626_privacy/tensorflow_privacy
parent
5433436b86
commit
5ad8676d38
1 changed files with 2 additions and 6 deletions
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||
|
||||
import collections
|
||||
import distutils
|
||||
import numbers
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
@ -56,11 +55,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
def make_global_state(self, l2_norm_clip, stddev):
|
||||
"""Creates a global state from the given parameters."""
|
||||
l2_norm_clip = float(l2_norm_clip) if isinstance(
|
||||
l2_norm_clip, numbers.Number) else tf.cast(l2_norm_clip, tf.float32)
|
||||
stddev = float(stddev) if isinstance(stddev, numbers.Number) else tf.cast(
|
||||
stddev, tf.float32)
|
||||
return self._GlobalState(l2_norm_clip, stddev)
|
||||
return self._GlobalState(tf.cast(l2_norm_clip, tf.float32),
|
||||
tf.cast(stddev, tf.float32))
|
||||
|
||||
def initial_global_state(self):
|
||||
return self.make_global_state(self._l2_norm_clip, self._stddev)
|
||||
|
|
Loading…
Reference in a new issue