Update tensorflow privacy to use NamedTuple instead of attrs.

This allows these objects to be traversed when nested in tree-like structures more easily.

PiperOrigin-RevId: 525532511
This commit is contained in:
Michael Reneer 2023-04-19 13:17:52 -07:00 committed by A. Unique TensorFlower
parent e362f51773
commit 60cb0dd2fb
3 changed files with 48 additions and 39 deletions

View file

@ -21,11 +21,11 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible.
"""
import abc
import collections
from typing import Any, Callable, Collection, Optional, Tuple, Union
from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union
import attr
import tensorflow as tf
# TODO(b/192464750): find a proper place for the helper functions, privatize
@ -170,8 +170,7 @@ class StatelessValueGenerator(ValueGenerator):
return self.value_fn(), state
@attr.s(eq=False, frozen=True, slots=True)
class TreeState(object):
class TreeState(NamedTuple):
"""Class defining state of the tree.
Attributes:
@ -183,9 +182,9 @@ class TreeState(object):
for the most recent leaf node.
value_generator_state: State of a stateful `ValueGenerator` for tree node.
"""
level_buffer = attr.ib(type=tf.Tensor)
level_buffer_idx = attr.ib(type=tf.Tensor)
value_generator_state = attr.ib(type=Any)
level_buffer: tf.Tensor
level_buffer_idx: tf.Tensor
value_generator_state: Any
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.

View file

@ -33,9 +33,11 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
"""
import attr
from typing import Any, NamedTuple
import dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP.
"""
@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for Tree sum queries.
Attributes:
@ -94,9 +95,9 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
"""
tree_state = attr.ib()
clip_value = attr.ib()
samples_cumulative_sum = attr.ib()
tree_state: Any
clip_value: Any
samples_cumulative_sum: Any
def __init__(self,
record_specs,
@ -182,10 +183,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
new_global_state = attr.evolve(
global_state,
new_global_state = TreeCumulativeSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state)
)
event = dp_accounting.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event
@ -206,10 +208,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
state for the next cumulative sum.
"""
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
return TreeCumulativeSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=noised_results,
tree_state=new_tree_state)
)
@classmethod
def build_l2_gaussian_query(cls,
@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP.
"""
@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for Tree sum queries.
Attributes:
@ -323,9 +325,9 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample.
"""
tree_state = attr.ib()
clip_value = attr.ib()
previous_tree_noise = attr.ib()
tree_state: Any
clip_value: Any
previous_tree_noise: Any
def __init__(self,
record_specs,
@ -426,8 +428,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise,
global_state.previous_tree_noise)
new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
new_global_state = TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=tree_noise,
)
event = dp_accounting.UnsupportedDpEvent()
return noised_sample, new_global_state, event
@ -448,10 +453,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
"""
del noised_results
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
return TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=self._zero_initial_noise(),
tree_state=new_tree_state)
)
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
noise_generator_state = global_state.tree_state.value_generator_state
@ -459,10 +465,16 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
tree_aggregation.GaussianNoiseGenerator)
noise_generator_state = self._tree_aggregator.value_generator.make_state(
noise_generator_state.seeds, stddev)
new_tree_state = attr.evolve(
global_state.tree_state, value_generator_state=noise_generator_state)
return attr.evolve(
global_state, clip_value=clip_norm, tree_state=new_tree_state)
new_tree_state = tree_aggregation.TreeState(
level_buffer=global_state.tree_state.level_buffer,
level_buffer_idx=global_state.tree_state.level_buffer_idx,
value_generator_state=noise_generator_state,
)
return TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=clip_norm,
previous_tree_noise=global_state.previous_tree_noise,
)
@classmethod
def build_l2_gaussian_query(cls,

View file

@ -18,9 +18,8 @@
import distutils
import math
from typing import Optional
from typing import Any, NamedTuple, Optional
import attr
import dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
@ -102,8 +101,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
Improves efficiency and reduces noise scale.
"""
@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for TreeRangeSumQuery.
Attributes:
@ -111,8 +109,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
internal node has).
inner_query_state: The global state of the inner query.
"""
arity = attr.ib()
inner_query_state = attr.ib()
arity: Any
inner_query_state: Any
def __init__(self,
inner_query: dp_query.SumAggregationDPQuery,