The previous version uses tf.nest.map_structure
to apply add_noise
to a tf.RaggedTensor
. This causes a bug when used in tensorflow federated because tf.nest.map_structure
will also map add_noise
to the tensor for shape information in tf.RaggedTensor
. This causes failure when tff conducts automatic type conversion.
Also use fixed random seed to avoid flaky timeouts and testing failures. PiperOrigin-RevId: 384573740
This commit is contained in:
parent
7f44b02456
commit
2cafe28d8d
2 changed files with 39 additions and 42 deletions
|
@ -25,10 +25,10 @@ what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
|
|||
"""
|
||||
import distutils
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import attr
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
|
||||
|
@ -442,16 +442,20 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
|||
return tree
|
||||
|
||||
|
||||
def _get_add_noise(stddev):
|
||||
def _get_add_noise(stddev, seed: int = None):
|
||||
"""Utility function to decide which `add_noise` to use according to tf version."""
|
||||
if distutils.version.LooseVersion(
|
||||
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
||||
|
||||
# The seed should be only used for testing purpose.
|
||||
if seed is not None:
|
||||
tf.random.set_seed(seed)
|
||||
|
||||
def add_noise(v):
|
||||
return v + tf.random.normal(
|
||||
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
|
||||
else:
|
||||
random_normal = tf.random_normal_initializer(stddev=stddev)
|
||||
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
|
||||
|
||||
def add_noise(v):
|
||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||
|
@ -478,17 +482,16 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
"""Class defining global state for `CentralTreeSumQuery`.
|
||||
|
||||
Attributes:
|
||||
stddev: The stddev of the noise added to each node in the tree.
|
||||
arity: The branching factor of the tree (i.e. the number of children each
|
||||
internal node has).
|
||||
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||
needed to bound the sensitivity and deploy differential privacy.
|
||||
"""
|
||||
stddev = attr.ib()
|
||||
arity = attr.ib()
|
||||
l1_bound = attr.ib()
|
||||
|
||||
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
|
||||
def __init__(self,
|
||||
stddev: float,
|
||||
arity: int = 2,
|
||||
l1_bound: int = 10,
|
||||
seed: Optional[int] = None):
|
||||
"""Initializes the `CentralTreeSumQuery`.
|
||||
|
||||
Args:
|
||||
|
@ -497,15 +500,17 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
arity: The branching factor of the tree.
|
||||
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||
needed to bound the sensitivity and deploy differential privacy.
|
||||
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
|
||||
test purpose.
|
||||
"""
|
||||
self._stddev = stddev
|
||||
self._arity = arity
|
||||
self._l1_bound = l1_bound
|
||||
self._seed = seed
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
return CentralTreeSumQuery.GlobalState(
|
||||
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
|
||||
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||
|
@ -536,10 +541,9 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
The jth node on the ith layer of the tree can be accessed by tree[i][j]
|
||||
where tree is the returned value.
|
||||
"""
|
||||
add_noise = _get_add_noise(self._stddev)
|
||||
tree = _build_tree_from_leaf(sample_state, global_state.arity)
|
||||
return tf.nest.map_structure(
|
||||
add_noise, tree, expand_composites=True), global_state
|
||||
add_noise = _get_add_noise(self._stddev, self._seed)
|
||||
tree = _build_tree_from_leaf(sample_state, self._arity)
|
||||
return tf.map_fn(add_noise, tree), global_state
|
||||
|
||||
|
||||
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||
|
@ -577,7 +581,11 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
arity = attr.ib()
|
||||
l1_bound = attr.ib()
|
||||
|
||||
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
|
||||
def __init__(self,
|
||||
stddev: float,
|
||||
arity: int = 2,
|
||||
l1_bound: int = 10,
|
||||
seed: Optional[int] = None):
|
||||
"""Initializes the `DistributedTreeSumQuery`.
|
||||
|
||||
Args:
|
||||
|
@ -585,10 +593,13 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
arity: The branching factor of the tree.
|
||||
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||
needed to bound the sensitivity and deploy differential privacy.
|
||||
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
|
||||
test purpose.
|
||||
"""
|
||||
self._stddev = stddev
|
||||
self._arity = arity
|
||||
self._l1_bound = l1_bound
|
||||
self._seed = seed
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||
|
@ -628,9 +639,9 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
|||
use_norm=l1_norm)
|
||||
preprocessed_record = preprocessed_record[0]
|
||||
|
||||
add_noise = _get_add_noise(self._stddev)
|
||||
add_noise = _get_add_noise(self._stddev, self._seed)
|
||||
tree = _build_tree_from_leaf(preprocessed_record, arity)
|
||||
noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True)
|
||||
noisy_tree = tf.map_fn(add_noise, tree)
|
||||
|
||||
# The following codes reshape the output vector so the output shape of can
|
||||
# be statically inferred. This is useful when used with
|
||||
|
|
|
@ -502,21 +502,15 @@ class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||
)
|
||||
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
|
||||
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev)
|
||||
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
sample_state_list = []
|
||||
for _ in range(1000):
|
||||
sample_state, _ = query.get_noised_result(preprocessed_record,
|
||||
global_state)
|
||||
sample_state_list.append(sample_state.flat_values.numpy())
|
||||
expectation = np.mean(sample_state_list, axis=0)
|
||||
variance = np.std(sample_state_list, axis=0)
|
||||
|
||||
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
|
||||
sample_state, _ = query.get_noised_result(preprocessed_record, global_state)
|
||||
|
||||
self.assertAllClose(
|
||||
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
|
||||
sample_state.flat_values, expected_tree, atol=3 * stddev)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||
|
@ -556,8 +550,7 @@ class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
def test_derive_sample_params(self):
|
||||
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
|
||||
global_state = query.initial_global_state()
|
||||
stddev, arity, l1_bound = query.derive_sample_params(
|
||||
global_state)
|
||||
stddev, arity, l1_bound = query.derive_sample_params(global_state)
|
||||
self.assertAllClose(stddev, NOISE_STD)
|
||||
self.assertAllClose(arity, 2)
|
||||
self.assertAllClose(l1_bound, 10)
|
||||
|
@ -587,21 +580,14 @@ class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||
)
|
||||
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
|
||||
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev)
|
||||
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||
stddev=stddev, seed=0)
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
|
||||
preprocessed_record_list = []
|
||||
for _ in range(1000):
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
preprocessed_record_list.append(preprocessed_record.numpy())
|
||||
preprocessed_record = query.preprocess_record(params, record)
|
||||
|
||||
expectation = np.mean(preprocessed_record_list, axis=0)
|
||||
variance = np.std(preprocessed_record_list, axis=0)
|
||||
|
||||
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
|
||||
self.assertAllClose(
|
||||
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
|
||||
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||
|
|
Loading…
Reference in a new issue