The previous version used `tf.nest.map_structure` to apply `add_noise` to a `tf.RaggedTensor`. This causes a bug when used with TensorFlow Federated (TFF), because `tf.nest.map_structure` also applies `add_noise` to the tensor that holds the shape information inside the `tf.RaggedTensor`, not only to its values. As a result, TFF's automatic type conversion fails.

Also use a fixed random seed in the tests to avoid flaky timeouts and test failures.

PiperOrigin-RevId: 384573740
This commit is contained in:
A. Unique TensorFlower 2021-07-13 16:13:50 -07:00
parent 7f44b02456
commit 2cafe28d8d
2 changed files with 39 additions and 42 deletions

View file

@ -25,10 +25,10 @@ what is defined in `dp_query.DPQuery`, is a histogram (i.e. the leaf nodes).
"""
import distutils
import math
from typing import Optional
import attr
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
@ -442,16 +442,20 @@ def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
return tree
def _get_add_noise(stddev):
def _get_add_noise(stddev, seed: int = None):
"""Utility function to decide which `add_noise` to use according to tf version."""
if distutils.version.LooseVersion(
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
# The seed should be only used for testing purpose.
if seed is not None:
tf.random.set_seed(seed)
def add_noise(v):
return v + tf.random.normal(
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
else:
random_normal = tf.random_normal_initializer(stddev=stddev)
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
@ -478,17 +482,16 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
"""Class defining global state for `CentralTreeSumQuery`.
Attributes:
stddev: The stddev of the noise added to each node in the tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
stddev = attr.ib()
arity = attr.ib()
l1_bound = attr.ib()
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
def __init__(self,
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `CentralTreeSumQuery`.
Args:
@ -497,15 +500,17 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
self._seed = seed
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return CentralTreeSumQuery.GlobalState(
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -536,10 +541,9 @@ class CentralTreeSumQuery(dp_query.SumAggregationDPQuery):
The jth node on the ith layer of the tree can be accessed by tree[i][j]
where tree is the returned value.
"""
add_noise = _get_add_noise(self._stddev)
tree = _build_tree_from_leaf(sample_state, global_state.arity)
return tf.nest.map_structure(
add_noise, tree, expand_composites=True), global_state
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(sample_state, self._arity)
return tf.map_fn(add_noise, tree), global_state
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
@ -577,7 +581,11 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
arity = attr.ib()
l1_bound = attr.ib()
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
def __init__(self,
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `DistributedTreeSumQuery`.
Args:
@ -585,10 +593,13 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
self._seed = seed
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
@ -628,9 +639,9 @@ class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
use_norm=l1_norm)
preprocessed_record = preprocessed_record[0]
add_noise = _get_add_noise(self._stddev)
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(preprocessed_record, arity)
noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True)
noisy_tree = tf.map_fn(add_noise, tree)
# The following codes reshape the output vector so the output shape of can
# be statically inferred. This is useful when used with

View file

@ -502,21 +502,15 @@ class CentralTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev)
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state_list = []
for _ in range(1000):
sample_state, _ = query.get_noised_result(preprocessed_record,
global_state)
sample_state_list.append(sample_state.flat_values.numpy())
expectation = np.mean(sample_state_list, axis=0)
variance = np.std(sample_state_list, axis=0)
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
sample_state, _ = query.get_noised_result(preprocessed_record, global_state)
self.assertAllClose(
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
sample_state.flat_values, expected_tree, atol=3 * stddev)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
@ -556,8 +550,7 @@ class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_derive_sample_params(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
stddev, arity, l1_bound = query.derive_sample_params(
global_state)
stddev, arity, l1_bound = query.derive_sample_params(global_state)
self.assertAllClose(stddev, NOISE_STD)
self.assertAllClose(arity, 2)
self.assertAllClose(l1_bound, 10)
@ -587,21 +580,14 @@ class DistributedTreeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev)
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record_list = []
for _ in range(1000):
preprocessed_record = query.preprocess_record(params, record)
preprocessed_record_list.append(preprocessed_record.numpy())
preprocessed_record = query.preprocess_record(params, record)
expectation = np.mean(preprocessed_record_list, axis=0)
variance = np.std(preprocessed_record_list, axis=0)
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
self.assertAllClose(
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)
@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),