Update tensorflow privacy to use NamedTuple instead of attrs.

This allows these objects to be traversed when nested in tree-like structures more easily.

PiperOrigin-RevId: 525532511
This commit is contained in:
Michael Reneer 2023-04-19 13:17:52 -07:00 committed by A. Unique TensorFlower
parent e362f51773
commit 60cb0dd2fb
3 changed files with 48 additions and 39 deletions

View file

@ -21,11 +21,11 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
as helper functions for `tree_aggregation_query`. This module and helper as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible. functions are publicly accessible.
""" """
import abc import abc
import collections import collections
from typing import Any, Callable, Collection, Optional, Tuple, Union from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union
import attr
import tensorflow as tf import tensorflow as tf
# TODO(b/192464750): find a proper place for the helper functions, privatize # TODO(b/192464750): find a proper place for the helper functions, privatize
@ -170,8 +170,7 @@ class StatelessValueGenerator(ValueGenerator):
return self.value_fn(), state return self.value_fn(), state
@attr.s(eq=False, frozen=True, slots=True) class TreeState(NamedTuple):
class TreeState(object):
"""Class defining state of the tree. """Class defining state of the tree.
Attributes: Attributes:
@ -183,9 +182,9 @@ class TreeState(object):
for the most recent leaf node. for the most recent leaf node.
value_generator_state: State of a stateful `ValueGenerator` for tree node. value_generator_state: State of a stateful `ValueGenerator` for tree node.
""" """
level_buffer = attr.ib(type=tf.Tensor) level_buffer: tf.Tensor
level_buffer_idx = attr.ib(type=tf.Tensor) level_buffer_idx: tf.Tensor
value_generator_state = attr.ib(type=Any) value_generator_state: Any
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`. # TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.

View file

@ -33,9 +33,11 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0] eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
""" """
import attr from typing import Any, NamedTuple
import dp_accounting import dp_accounting
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation
@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP. O(clip_norm*log(T)/eps) to guarantee eps-DP.
""" """
@attr.s(frozen=True) class GlobalState(NamedTuple):
class GlobalState(object):
"""Class defining global state for Tree sum queries. """Class defining global state for Tree sum queries.
Attributes: Attributes:
@ -94,9 +95,9 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
clip_value: The clipping value to be passed to clip_fn. clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time. samples_cumulative_sum: Noiseless cumulative sum of samples over time.
""" """
tree_state = attr.ib() tree_state: Any
clip_value = attr.ib() clip_value: Any
samples_cumulative_sum = attr.ib() samples_cumulative_sum: Any
def __init__(self, def __init__(self,
record_specs, record_specs,
@ -182,10 +183,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
global_state.tree_state) global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum, noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise) cumulative_sum_noise)
new_global_state = attr.evolve( new_global_state = TreeCumulativeSumQuery.GlobalState(
global_state, tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=new_cumulative_sum, samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state) )
event = dp_accounting.UnsupportedDpEvent() event = dp_accounting.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event return noised_cumulative_sum, new_global_state, event
@ -206,10 +208,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
state for the next cumulative sum. state for the next cumulative sum.
""" """
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state) new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve( return TreeCumulativeSumQuery.GlobalState(
global_state, tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=noised_results, samples_cumulative_sum=noised_results,
tree_state=new_tree_state) )
@classmethod @classmethod
def build_l2_gaussian_query(cls, def build_l2_gaussian_query(cls,
@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP. O(clip_norm*log(T)/eps) to guarantee eps-DP.
""" """
@attr.s(frozen=True) class GlobalState(NamedTuple):
class GlobalState(object):
"""Class defining global state for Tree sum queries. """Class defining global state for Tree sum queries.
Attributes: Attributes:
@ -323,9 +325,9 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
previous_tree_noise: Cumulative noise by tree aggregation from the previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample. previous time the query is called on a sample.
""" """
tree_state = attr.ib() tree_state: Any
clip_value = attr.ib() clip_value: Any
previous_tree_noise = attr.ib() previous_tree_noise: Any
def __init__(self, def __init__(self,
record_specs, record_specs,
@ -426,8 +428,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c, noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise, sample_state, tree_noise,
global_state.previous_tree_noise) global_state.previous_tree_noise)
new_global_state = attr.evolve( new_global_state = TreeResidualSumQuery.GlobalState(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state) tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=tree_noise,
)
event = dp_accounting.UnsupportedDpEvent() event = dp_accounting.UnsupportedDpEvent()
return noised_sample, new_global_state, event return noised_sample, new_global_state, event
@ -448,10 +453,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
""" """
del noised_results del noised_results
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state) new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve( return TreeResidualSumQuery.GlobalState(
global_state, tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=self._zero_initial_noise(), previous_tree_noise=self._zero_initial_noise(),
tree_state=new_tree_state) )
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev): def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
noise_generator_state = global_state.tree_state.value_generator_state noise_generator_state = global_state.tree_state.value_generator_state
@ -459,10 +465,16 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
tree_aggregation.GaussianNoiseGenerator) tree_aggregation.GaussianNoiseGenerator)
noise_generator_state = self._tree_aggregator.value_generator.make_state( noise_generator_state = self._tree_aggregator.value_generator.make_state(
noise_generator_state.seeds, stddev) noise_generator_state.seeds, stddev)
new_tree_state = attr.evolve( new_tree_state = tree_aggregation.TreeState(
global_state.tree_state, value_generator_state=noise_generator_state) level_buffer=global_state.tree_state.level_buffer,
return attr.evolve( level_buffer_idx=global_state.tree_state.level_buffer_idx,
global_state, clip_value=clip_norm, tree_state=new_tree_state) value_generator_state=noise_generator_state,
)
return TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=clip_norm,
previous_tree_noise=global_state.previous_tree_noise,
)
@classmethod @classmethod
def build_l2_gaussian_query(cls, def build_l2_gaussian_query(cls,

View file

@ -18,9 +18,8 @@
import distutils import distutils
import math import math
from typing import Optional from typing import Any, NamedTuple, Optional
import attr
import dp_accounting import dp_accounting
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
@ -102,8 +101,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
Improves efficiency and reduces noise scale. Improves efficiency and reduces noise scale.
""" """
@attr.s(frozen=True) class GlobalState(NamedTuple):
class GlobalState(object):
"""Class defining global state for TreeRangeSumQuery. """Class defining global state for TreeRangeSumQuery.
Attributes: Attributes:
@ -111,8 +109,8 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
internal node has). internal node has).
inner_query_state: The global state of the inner query. inner_query_state: The global state of the inner query.
""" """
arity = attr.ib() arity: Any
inner_query_state = attr.ib() inner_query_state: Any
def __init__(self, def __init__(self,
inner_query: dp_query.SumAggregationDPQuery, inner_query: dp_query.SumAggregationDPQuery,