Update tensorflow privacy to use NamedTuple
instead of attrs
.
This allows these objects to be traversed when nested in tree-like structures more easily. PiperOrigin-RevId: 525532511
This commit is contained in:
parent
e362f51773
commit
60cb0dd2fb
3 changed files with 48 additions and 39 deletions
|
@ -21,11 +21,11 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
|
|||
as helper functions for `tree_aggregation_query`. This module and helper
|
||||
functions are publicly accessible.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
from typing import Any, Callable, Collection, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
import tensorflow as tf
|
||||
|
||||
# TODO(b/192464750): find a proper place for the helper functions, privatize
|
||||
|
@ -170,8 +170,7 @@ class StatelessValueGenerator(ValueGenerator):
|
|||
return self.value_fn(), state
|
||||
|
||||
|
||||
@attr.s(eq=False, frozen=True, slots=True)
|
||||
class TreeState(object):
|
||||
class TreeState(NamedTuple):
|
||||
"""Class defining state of the tree.
|
||||
|
||||
Attributes:
|
||||
|
@ -183,9 +182,9 @@ class TreeState(object):
|
|||
for the most recent leaf node.
|
||||
value_generator_state: State of a stateful `ValueGenerator` for tree node.
|
||||
"""
|
||||
level_buffer = attr.ib(type=tf.Tensor)
|
||||
level_buffer_idx = attr.ib(type=tf.Tensor)
|
||||
value_generator_state = attr.ib(type=Any)
|
||||
level_buffer: tf.Tensor
|
||||
level_buffer_idx: tf.Tensor
|
||||
value_generator_state: Any
|
||||
|
||||
|
||||
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
|
||||
|
|
|
@ -33,9 +33,11 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve
|
|||
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
|
||||
"""
|
||||
|
||||
import attr
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import dp_accounting
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
|
||||
|
@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
class GlobalState(object):
|
||||
class GlobalState(NamedTuple):
|
||||
"""Class defining global state for Tree sum queries.
|
||||
|
||||
Attributes:
|
||||
|
@ -94,9 +95,9 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
clip_value: The clipping value to be passed to clip_fn.
|
||||
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
|
||||
"""
|
||||
tree_state = attr.ib()
|
||||
clip_value = attr.ib()
|
||||
samples_cumulative_sum = attr.ib()
|
||||
tree_state: Any
|
||||
clip_value: Any
|
||||
samples_cumulative_sum: Any
|
||||
|
||||
def __init__(self,
|
||||
record_specs,
|
||||
|
@ -182,10 +183,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
global_state.tree_state)
|
||||
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
||||
cumulative_sum_noise)
|
||||
new_global_state = attr.evolve(
|
||||
global_state,
|
||||
new_global_state = TreeCumulativeSumQuery.GlobalState(
|
||||
tree_state=new_tree_state,
|
||||
clip_value=global_state.clip_value,
|
||||
samples_cumulative_sum=new_cumulative_sum,
|
||||
tree_state=new_tree_state)
|
||||
)
|
||||
event = dp_accounting.UnsupportedDpEvent()
|
||||
return noised_cumulative_sum, new_global_state, event
|
||||
|
||||
|
@ -206,10 +208,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
state for the next cumulative sum.
|
||||
"""
|
||||
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
||||
return attr.evolve(
|
||||
global_state,
|
||||
return TreeCumulativeSumQuery.GlobalState(
|
||||
tree_state=new_tree_state,
|
||||
clip_value=global_state.clip_value,
|
||||
samples_cumulative_sum=noised_results,
|
||||
tree_state=new_tree_state)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_l2_gaussian_query(cls,
|
||||
|
@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
class GlobalState(object):
|
||||
class GlobalState(NamedTuple):
|
||||
"""Class defining global state for Tree sum queries.
|
||||
|
||||
Attributes:
|
||||
|
@ -323,9 +325,9 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
previous_tree_noise: Cumulative noise by tree aggregation from the
|
||||
previous time the query is called on a sample.
|
||||
"""
|
||||
tree_state = attr.ib()
|
||||
clip_value = attr.ib()
|
||||
previous_tree_noise = attr.ib()
|
||||
tree_state: Any
|
||||
clip_value: Any
|
||||
previous_tree_noise: Any
|
||||
|
||||
def __init__(self,
|
||||
record_specs,
|
||||
|
@ -426,8 +428,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
|
||||
sample_state, tree_noise,
|
||||
global_state.previous_tree_noise)
|
||||
new_global_state = attr.evolve(
|
||||
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
||||
new_global_state = TreeResidualSumQuery.GlobalState(
|
||||
tree_state=new_tree_state,
|
||||
clip_value=global_state.clip_value,
|
||||
previous_tree_noise=tree_noise,
|
||||
)
|
||||
event = dp_accounting.UnsupportedDpEvent()
|
||||
return noised_sample, new_global_state, event
|
||||
|
||||
|
@ -448,10 +453,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
"""
|
||||
del noised_results
|
||||
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
||||
return attr.evolve(
|
||||
global_state,
|
||||
return TreeResidualSumQuery.GlobalState(
|
||||
tree_state=new_tree_state,
|
||||
clip_value=global_state.clip_value,
|
||||
previous_tree_noise=self._zero_initial_noise(),
|
||||
tree_state=new_tree_state)
|
||||
)
|
||||
|
||||
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
|
||||
noise_generator_state = global_state.tree_state.value_generator_state
|
||||
|
@ -459,10 +465,16 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
tree_aggregation.GaussianNoiseGenerator)
|
||||
noise_generator_state = self._tree_aggregator.value_generator.make_state(
|
||||
noise_generator_state.seeds, stddev)
|
||||
new_tree_state = attr.evolve(
|
||||
global_state.tree_state, value_generator_state=noise_generator_state)
|
||||
return attr.evolve(
|
||||
global_state, clip_value=clip_norm, tree_state=new_tree_state)
|
||||
new_tree_state = tree_aggregation.TreeState(
|
||||
level_buffer=global_state.tree_state.level_buffer,
|
||||
level_buffer_idx=global_state.tree_state.level_buffer_idx,
|
||||
value_generator_state=noise_generator_state,
|
||||
)
|
||||
return TreeResidualSumQuery.GlobalState(
|
||||
tree_state=new_tree_state,
|
||||
clip_value=clip_norm,
|
||||
previous_tree_noise=global_state.previous_tree_noise,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_l2_gaussian_query(cls,
|
||||
|
|
|
@ -18,9 +18,8 @@
|
|||
|
||||
import distutils
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Any, NamedTuple, Optional
|
||||
|
||||
import attr
|
||||
import dp_accounting
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||
|
@ -102,8 +101,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
Improves efficiency and reduces noise scale.
|
||||
"""
|
||||
|
||||
@attr.s(frozen=True)
|
||||
class GlobalState(object):
|
||||
class GlobalState(NamedTuple):
|
||||
"""Class defining global state for TreeRangeSumQuery.
|
||||
|
||||
Attributes:
|
||||
|
@ -111,8 +109,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
internal node has).
|
||||
inner_query_state: The global state of the inner query.
|
||||
"""
|
||||
arity = attr.ib()
|
||||
inner_query_state = attr.ib()
|
||||
arity: Any
|
||||
inner_query_state: Any
|
||||
|
||||
def __init__(self,
|
||||
inner_query: dp_query.SumAggregationDPQuery,
|
||||
|
|
Loading…
Reference in a new issue