Update tensorflow privacy to use NamedTuple
instead of attrs
.
This allows these objects to be traversed when nested in tree-like structures more easily. PiperOrigin-RevId: 525532511
This commit is contained in:
parent
e362f51773
commit
60cb0dd2fb
3 changed files with 48 additions and 39 deletions
|
@ -21,11 +21,11 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
|
||||||
as helper functions for `tree_aggregation_query`. This module and helper
|
as helper functions for `tree_aggregation_query`. This module and helper
|
||||||
functions are publicly accessible.
|
functions are publicly accessible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import collections
|
import collections
|
||||||
from typing import Any, Callable, Collection, Optional, Tuple, Union
|
from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
# TODO(b/192464750): find a proper place for the helper functions, privatize
|
# TODO(b/192464750): find a proper place for the helper functions, privatize
|
||||||
|
@ -170,8 +170,7 @@ class StatelessValueGenerator(ValueGenerator):
|
||||||
return self.value_fn(), state
|
return self.value_fn(), state
|
||||||
|
|
||||||
|
|
||||||
@attr.s(eq=False, frozen=True, slots=True)
|
class TreeState(NamedTuple):
|
||||||
class TreeState(object):
|
|
||||||
"""Class defining state of the tree.
|
"""Class defining state of the tree.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -183,9 +182,9 @@ class TreeState(object):
|
||||||
for the most recent leaf node.
|
for the most recent leaf node.
|
||||||
value_generator_state: State of a stateful `ValueGenerator` for tree node.
|
value_generator_state: State of a stateful `ValueGenerator` for tree node.
|
||||||
"""
|
"""
|
||||||
level_buffer = attr.ib(type=tf.Tensor)
|
level_buffer: tf.Tensor
|
||||||
level_buffer_idx = attr.ib(type=tf.Tensor)
|
level_buffer_idx: tf.Tensor
|
||||||
value_generator_state = attr.ib(type=Any)
|
value_generator_state: Any
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
|
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
|
||||||
|
|
|
@ -33,9 +33,11 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve
|
||||||
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
|
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import attr
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
import dp_accounting
|
import dp_accounting
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||||
|
|
||||||
|
@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@attr.s(frozen=True)
|
class GlobalState(NamedTuple):
|
||||||
class GlobalState(object):
|
|
||||||
"""Class defining global state for Tree sum queries.
|
"""Class defining global state for Tree sum queries.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -94,9 +95,9 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
clip_value: The clipping value to be passed to clip_fn.
|
clip_value: The clipping value to be passed to clip_fn.
|
||||||
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
|
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
|
||||||
"""
|
"""
|
||||||
tree_state = attr.ib()
|
tree_state: Any
|
||||||
clip_value = attr.ib()
|
clip_value: Any
|
||||||
samples_cumulative_sum = attr.ib()
|
samples_cumulative_sum: Any
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
record_specs,
|
record_specs,
|
||||||
|
@ -182,10 +183,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
global_state.tree_state)
|
global_state.tree_state)
|
||||||
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
|
||||||
cumulative_sum_noise)
|
cumulative_sum_noise)
|
||||||
new_global_state = attr.evolve(
|
new_global_state = TreeCumulativeSumQuery.GlobalState(
|
||||||
global_state,
|
tree_state=new_tree_state,
|
||||||
|
clip_value=global_state.clip_value,
|
||||||
samples_cumulative_sum=new_cumulative_sum,
|
samples_cumulative_sum=new_cumulative_sum,
|
||||||
tree_state=new_tree_state)
|
)
|
||||||
event = dp_accounting.UnsupportedDpEvent()
|
event = dp_accounting.UnsupportedDpEvent()
|
||||||
return noised_cumulative_sum, new_global_state, event
|
return noised_cumulative_sum, new_global_state, event
|
||||||
|
|
||||||
|
@ -206,10 +208,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
state for the next cumulative sum.
|
state for the next cumulative sum.
|
||||||
"""
|
"""
|
||||||
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
||||||
return attr.evolve(
|
return TreeCumulativeSumQuery.GlobalState(
|
||||||
global_state,
|
tree_state=new_tree_state,
|
||||||
|
clip_value=global_state.clip_value,
|
||||||
samples_cumulative_sum=noised_results,
|
samples_cumulative_sum=noised_results,
|
||||||
tree_state=new_tree_state)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_l2_gaussian_query(cls,
|
def build_l2_gaussian_query(cls,
|
||||||
|
@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
O(clip_norm*log(T)/eps) to guarantee eps-DP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@attr.s(frozen=True)
|
class GlobalState(NamedTuple):
|
||||||
class GlobalState(object):
|
|
||||||
"""Class defining global state for Tree sum queries.
|
"""Class defining global state for Tree sum queries.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -323,9 +325,9 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
previous_tree_noise: Cumulative noise by tree aggregation from the
|
previous_tree_noise: Cumulative noise by tree aggregation from the
|
||||||
previous time the query is called on a sample.
|
previous time the query is called on a sample.
|
||||||
"""
|
"""
|
||||||
tree_state = attr.ib()
|
tree_state: Any
|
||||||
clip_value = attr.ib()
|
clip_value: Any
|
||||||
previous_tree_noise = attr.ib()
|
previous_tree_noise: Any
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
record_specs,
|
record_specs,
|
||||||
|
@ -426,8 +428,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
|
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
|
||||||
sample_state, tree_noise,
|
sample_state, tree_noise,
|
||||||
global_state.previous_tree_noise)
|
global_state.previous_tree_noise)
|
||||||
new_global_state = attr.evolve(
|
new_global_state = TreeResidualSumQuery.GlobalState(
|
||||||
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
|
tree_state=new_tree_state,
|
||||||
|
clip_value=global_state.clip_value,
|
||||||
|
previous_tree_noise=tree_noise,
|
||||||
|
)
|
||||||
event = dp_accounting.UnsupportedDpEvent()
|
event = dp_accounting.UnsupportedDpEvent()
|
||||||
return noised_sample, new_global_state, event
|
return noised_sample, new_global_state, event
|
||||||
|
|
||||||
|
@ -448,10 +453,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""
|
"""
|
||||||
del noised_results
|
del noised_results
|
||||||
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
|
||||||
return attr.evolve(
|
return TreeResidualSumQuery.GlobalState(
|
||||||
global_state,
|
tree_state=new_tree_state,
|
||||||
|
clip_value=global_state.clip_value,
|
||||||
previous_tree_noise=self._zero_initial_noise(),
|
previous_tree_noise=self._zero_initial_noise(),
|
||||||
tree_state=new_tree_state)
|
)
|
||||||
|
|
||||||
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
|
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
|
||||||
noise_generator_state = global_state.tree_state.value_generator_state
|
noise_generator_state = global_state.tree_state.value_generator_state
|
||||||
|
@ -459,10 +465,16 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
tree_aggregation.GaussianNoiseGenerator)
|
tree_aggregation.GaussianNoiseGenerator)
|
||||||
noise_generator_state = self._tree_aggregator.value_generator.make_state(
|
noise_generator_state = self._tree_aggregator.value_generator.make_state(
|
||||||
noise_generator_state.seeds, stddev)
|
noise_generator_state.seeds, stddev)
|
||||||
new_tree_state = attr.evolve(
|
new_tree_state = tree_aggregation.TreeState(
|
||||||
global_state.tree_state, value_generator_state=noise_generator_state)
|
level_buffer=global_state.tree_state.level_buffer,
|
||||||
return attr.evolve(
|
level_buffer_idx=global_state.tree_state.level_buffer_idx,
|
||||||
global_state, clip_value=clip_norm, tree_state=new_tree_state)
|
value_generator_state=noise_generator_state,
|
||||||
|
)
|
||||||
|
return TreeResidualSumQuery.GlobalState(
|
||||||
|
tree_state=new_tree_state,
|
||||||
|
clip_value=clip_norm,
|
||||||
|
previous_tree_noise=global_state.previous_tree_noise,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_l2_gaussian_query(cls,
|
def build_l2_gaussian_query(cls,
|
||||||
|
|
|
@ -18,9 +18,8 @@
|
||||||
|
|
||||||
import distutils
|
import distutils
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import Any, NamedTuple, Optional
|
||||||
|
|
||||||
import attr
|
|
||||||
import dp_accounting
|
import dp_accounting
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||||
|
@ -102,8 +101,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
Improves efficiency and reduces noise scale.
|
Improves efficiency and reduces noise scale.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@attr.s(frozen=True)
|
class GlobalState(NamedTuple):
|
||||||
class GlobalState(object):
|
|
||||||
"""Class defining global state for TreeRangeSumQuery.
|
"""Class defining global state for TreeRangeSumQuery.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -111,8 +109,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
internal node has).
|
internal node has).
|
||||||
inner_query_state: The global state of the inner query.
|
inner_query_state: The global state of the inner query.
|
||||||
"""
|
"""
|
||||||
arity = attr.ib()
|
arity: Any
|
||||||
inner_query_state = attr.ib()
|
inner_query_state: Any
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
inner_query: dp_query.SumAggregationDPQuery,
|
inner_query: dp_query.SumAggregationDPQuery,
|
||||||
|
|
Loading…
Reference in a new issue