The previous version uses tf.nest.map_structure
to apply add_noise
to a tf.RaggedTensor
. This causes a bug when used in tensorflow federated because tf.nest.map_structure
will also map add_noise
to the tensor for shape information in tf.RaggedTensor
. This causes failure when tff conducts automatic type conversion.
Also use fixed random seed to avoid flaky timeouts and testing failures. PiperOrigin-RevId: 384573740
This commit is contained in:
parent
7f44b02456
commit
2cafe28d8d
2 changed files with 39 additions and 42 deletions
|
@ -25,10 +25,10 @@ what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
|
||||||
"""
|
"""
|
||||||
import distutils
|
import distutils
|
||||||
import math
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||||
|
|
||||||
|
@ -442,16 +442,20 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
|
|
||||||
def _get_add_noise(stddev):
|
def _get_add_noise(stddev, seed: int = None):
|
||||||
"""Utility function to decide which `add_noise` to use according to tf version."""
|
"""Utility function to decide which `add_noise` to use according to tf version."""
|
||||||
if distutils.version.LooseVersion(
|
if distutils.version.LooseVersion(
|
||||||
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
|
||||||
|
|
||||||
|
# The seed should be only used for testing purpose.
|
||||||
|
if seed is not None:
|
||||||
|
tf.random.set_seed(seed)
|
||||||
|
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.random.normal(
|
return v + tf.random.normal(
|
||||||
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
|
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
|
||||||
else:
|
else:
|
||||||
random_normal = tf.random_normal_initializer(stddev=stddev)
|
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
|
||||||
|
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||||
|
@ -478,17 +482,16 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Class defining global state for `CentralTreeSumQuery`.
|
"""Class defining global state for `CentralTreeSumQuery`.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
stddev: The stddev of the noise added to each node in the tree.
|
|
||||||
arity: The branching factor of the tree (i.e. the number of children each
|
|
||||||
internal node has).
|
|
||||||
l1_bound: An upper bound on the L1 norm of the input record. This is
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
needed to bound the sensitivity and deploy differential privacy.
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
"""
|
"""
|
||||||
stddev = attr.ib()
|
|
||||||
arity = attr.ib()
|
|
||||||
l1_bound = attr.ib()
|
l1_bound = attr.ib()
|
||||||
|
|
||||||
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
|
def __init__(self,
|
||||||
|
stddev: float,
|
||||||
|
arity: int = 2,
|
||||||
|
l1_bound: int = 10,
|
||||||
|
seed: Optional[int] = None):
|
||||||
"""Initializes the `CentralTreeSumQuery`.
|
"""Initializes the `CentralTreeSumQuery`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -497,15 +500,17 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
arity: The branching factor of the tree.
|
arity: The branching factor of the tree.
|
||||||
l1_bound: An upper bound on the L1 norm of the input record. This is
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
needed to bound the sensitivity and deploy differential privacy.
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
|
||||||
|
test purpose.
|
||||||
"""
|
"""
|
||||||
self._stddev = stddev
|
self._stddev = stddev
|
||||||
self._arity = arity
|
self._arity = arity
|
||||||
self._l1_bound = l1_bound
|
self._l1_bound = l1_bound
|
||||||
|
self._seed = seed
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
return CentralTreeSumQuery.GlobalState(
|
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)
|
||||||
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
|
||||||
|
@ -536,10 +541,9 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
The jth node on the ith layer of the tree can be accessed by tree[i][j]
|
The jth node on the ith layer of the tree can be accessed by tree[i][j]
|
||||||
where tree is the returned value.
|
where tree is the returned value.
|
||||||
"""
|
"""
|
||||||
add_noise = _get_add_noise(self._stddev)
|
add_noise = _get_add_noise(self._stddev, self._seed)
|
||||||
tree = _build_tree_from_leaf(sample_state, global_state.arity)
|
tree = _build_tree_from_leaf(sample_state, self._arity)
|
||||||
return tf.nest.map_structure(
|
return tf.map_fn(add_noise, tree), global_state
|
||||||
add_noise, tree, expand_composites=True), global_state
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
@ -577,7 +581,11 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
arity = attr.ib()
|
arity = attr.ib()
|
||||||
l1_bound = attr.ib()
|
l1_bound = attr.ib()
|
||||||
|
|
||||||
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
|
def __init__(self,
|
||||||
|
stddev: float,
|
||||||
|
arity: int = 2,
|
||||||
|
l1_bound: int = 10,
|
||||||
|
seed: Optional[int] = None):
|
||||||
"""Initializes the `DistributedTreeSumQuery`.
|
"""Initializes the `DistributedTreeSumQuery`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -585,10 +593,13 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
arity: The branching factor of the tree.
|
arity: The branching factor of the tree.
|
||||||
l1_bound: An upper bound on the L1 norm of the input record. This is
|
l1_bound: An upper bound on the L1 norm of the input record. This is
|
||||||
needed to bound the sensitivity and deploy differential privacy.
|
needed to bound the sensitivity and deploy differential privacy.
|
||||||
|
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
|
||||||
|
test purpose.
|
||||||
"""
|
"""
|
||||||
self._stddev = stddev
|
self._stddev = stddev
|
||||||
self._arity = arity
|
self._arity = arity
|
||||||
self._l1_bound = l1_bound
|
self._l1_bound = l1_bound
|
||||||
|
self._seed = seed
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
|
||||||
|
@ -628,9 +639,9 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
use_norm=l1_norm)
|
use_norm=l1_norm)
|
||||||
preprocessed_record = preprocessed_record[0]
|
preprocessed_record = preprocessed_record[0]
|
||||||
|
|
||||||
add_noise = _get_add_noise(self._stddev)
|
add_noise = _get_add_noise(self._stddev, self._seed)
|
||||||
tree = _build_tree_from_leaf(preprocessed_record, arity)
|
tree = _build_tree_from_leaf(preprocessed_record, arity)
|
||||||
noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True)
|
noisy_tree = tf.map_fn(add_noise, tree)
|
||||||
|
|
||||||
# The following codes reshape the output vector so the output shape of can
|
# The following codes reshape the output vector so the output shape of can
|
||||||
# be statically inferred. This is useful when used with
|
# be statically inferred. This is useful when used with
|
||||||
|
|
|
@ -502,21 +502,15 @@ class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
)
|
)
|
||||||
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
|
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
|
||||||
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev)
|
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
preprocessed_record = query.preprocess_record(params, record)
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
sample_state_list = []
|
|
||||||
for _ in range(1000):
|
|
||||||
sample_state, _ = query.get_noised_result(preprocessed_record,
|
|
||||||
global_state)
|
|
||||||
sample_state_list.append(sample_state.flat_values.numpy())
|
|
||||||
expectation = np.mean(sample_state_list, axis=0)
|
|
||||||
variance = np.std(sample_state_list, axis=0)
|
|
||||||
|
|
||||||
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
|
sample_state, _ = query.get_noised_result(preprocessed_record, global_state)
|
||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
|
sample_state.flat_values, expected_tree, atol=3 * stddev)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
@ -556,8 +550,7 @@ class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
def test_derive_sample_params(self):
|
def test_derive_sample_params(self):
|
||||||
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
|
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
stddev, arity, l1_bound = query.derive_sample_params(
|
stddev, arity, l1_bound = query.derive_sample_params(global_state)
|
||||||
global_state)
|
|
||||||
self.assertAllClose(stddev, NOISE_STD)
|
self.assertAllClose(stddev, NOISE_STD)
|
||||||
self.assertAllClose(arity, 2)
|
self.assertAllClose(arity, 2)
|
||||||
self.assertAllClose(l1_bound, 10)
|
self.assertAllClose(l1_bound, 10)
|
||||||
|
@ -587,21 +580,14 @@ class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
|
||||||
)
|
)
|
||||||
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
|
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
|
||||||
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev)
|
query = tree_aggregation_query.DistributedTreeSumQuery(
|
||||||
|
stddev=stddev, seed=0)
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
params = query.derive_sample_params(global_state)
|
params = query.derive_sample_params(global_state)
|
||||||
|
|
||||||
preprocessed_record_list = []
|
|
||||||
for _ in range(1000):
|
|
||||||
preprocessed_record = query.preprocess_record(params, record)
|
preprocessed_record = query.preprocess_record(params, record)
|
||||||
preprocessed_record_list.append(preprocessed_record.numpy())
|
|
||||||
|
|
||||||
expectation = np.mean(preprocessed_record_list, axis=0)
|
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)
|
||||||
variance = np.std(preprocessed_record_list, axis=0)
|
|
||||||
|
|
||||||
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
|
|
||||||
self.assertAllClose(
|
|
||||||
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
|
||||||
|
|
Loading…
Reference in a new issue