forked from 626_privacy/tensorflow_privacy
Cast to ensure record of NoPrivacyAverageQuery is float for compatibility with sample_state.
PiperOrigin-RevId: 249909614
This commit is contained in:
parent
15c07250a1
commit
7636945566
1 changed files with 3 additions and 3 deletions
|
@ -55,17 +55,17 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
|||
def preprocess_record(self, params, record, weight=1):
|
||||
"""Multiplies record by weight."""
|
||||
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||
return (weighted_record, weight)
|
||||
return (weighted_record, tf.cast(weight, tf.float32))
|
||||
|
||||
def accumulate_record(self, params, sample_state, record, weight=1):
|
||||
"""Accumulates record, multiplying by weight."""
|
||||
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||
return self.accumulate_preprocessed_record(
|
||||
sample_state, (weighted_record, weight))
|
||||
sample_state, (weighted_record, tf.cast(weight, tf.float32)))
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
sum_state, denominator = sample_state
|
||||
|
||||
return nest.map_structure(
|
||||
lambda t: tf.truediv(t, denominator), sum_state), ()
|
||||
lambda t: tf.truediv(t, denominator), sum_state), global_state
|
||||
|
|
Loading…
Reference in a new issue