Adds DistributedSkellamQuery
to public TF Privacy.
PiperOrigin-RevId: 429664212
This commit is contained in:
parent
ffc29e1d82
commit
89de03e0db
3 changed files with 355 additions and 0 deletions
|
@ -75,6 +75,28 @@ py_test(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "distributed_skellam_query",
|
||||
srcs = ["distributed_skellam_query.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dp_query",
|
||||
":normalized_query",
|
||||
"//tensorflow_privacy/privacy/analysis:dp_event",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "distributed_skellam_query_test",
|
||||
srcs = ["distributed_skellam_query_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":distributed_skellam_query",
|
||||
":test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gaussian_query",
|
||||
srcs = ["gaussian_query.py"],
|
||||
|
|
165
tensorflow_privacy/privacy/dp_query/distributed_skellam_query.py
Normal file
165
tensorflow_privacy/privacy/dp_query/distributed_skellam_query.py
Normal file
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2022, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implements DPQuery interface for Skellam average queries."""
|
||||
|
||||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.analysis import dp_event
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||
|
||||
|
||||
class DistributedSkellamSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery interface for discrete distributed sum queries.
|
||||
|
||||
This implementation is for the distributed queries where the Skellam noise
|
||||
is applied locally to a discrete vector that matches the norm bound.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple(
|
||||
'_GlobalState', ['l1_norm_bound', 'l2_norm_bound', 'local_stddev'])
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_SampleParams = collections.namedtuple(
|
||||
'_SampleParams', ['l1_norm_bound', 'l2_norm_bound', 'local_stddev'])
|
||||
|
||||
def __init__(self, l1_norm_bound, l2_norm_bound, local_stddev):
|
||||
"""Initializes the DistributedSkellamSumQuery.
|
||||
|
||||
Args:
|
||||
l1_norm_bound: The l1 norm bound to verify for each record.
|
||||
l2_norm_bound: The l2 norm bound to verify for each record.
|
||||
local_stddev: The standard deviation of the Skellam distribution.
|
||||
"""
|
||||
self._l1_norm_bound = l1_norm_bound
|
||||
self._l2_norm_bound = l2_norm_bound
|
||||
self._local_stddev = local_stddev
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Since we operate on discrete values, use int for L1 bound and float for L2 bound."""
|
||||
return self._GlobalState(
|
||||
tf.cast(self._l1_norm_bound, tf.int32),
|
||||
tf.cast(self._l2_norm_bound, tf.float32),
|
||||
tf.cast(self._local_stddev, tf.float32))
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
return self._SampleParams(global_state.l1_norm_bound,
|
||||
global_state.l2_norm_bound,
|
||||
global_state.local_stddev)
|
||||
|
||||
def add_noise_to_sample(self, local_stddev, record):
|
||||
"""Adds Skellam noise to the sample.
|
||||
|
||||
We use difference of two Poisson random variable with lambda hyperparameter
|
||||
that equals 'local_stddev**2/2' that results in a standard deviation
|
||||
'local_stddev' for the Skellam noise to be added locally.
|
||||
|
||||
Args:
|
||||
local_stddev: The standard deviation of the local Skellam noise.
|
||||
record: The record to be processed.
|
||||
|
||||
Returns:
|
||||
A record with added noise.
|
||||
"""
|
||||
# Use float64 as the stddev could be large after quantization.
|
||||
local_stddev = tf.cast(local_stddev, tf.float64)
|
||||
poisson_lam = 0.5 * local_stddev * local_stddev
|
||||
|
||||
def add_noise(v):
|
||||
poissons = tf.random.stateless_poisson(
|
||||
shape=tf.concat([tf.shape(v), [2]], axis=0), # Two draws of Poisson.
|
||||
seed=tf.cast([tf.timestamp() * 10**6, 0], tf.int64),
|
||||
lam=[poisson_lam, poisson_lam],
|
||||
dtype=tf.int64)
|
||||
return v + tf.cast(poissons[..., 0] - poissons[..., 1], v.dtype)
|
||||
|
||||
return tf.nest.map_structure(add_noise, record)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Check record norm and add noise to the record.
|
||||
|
||||
For both L1 and L2 norms we compute a global norm of the provided record.
|
||||
Since the record contains int32 tensors we cast them into float32 to
|
||||
compute L2 norm. In the end we run three asserts: type, l1, and l2 norms.
|
||||
|
||||
Args:
|
||||
params: The parameters for the particular sample.
|
||||
record: The record to be processed.
|
||||
|
||||
Returns:
|
||||
A tuple (preprocessed_records, params) where `preprocessed_records` is
|
||||
the structure of preprocessed tensors, and params contains sample
|
||||
params.
|
||||
"""
|
||||
record_as_list = tf.nest.flatten(record)
|
||||
record_as_float = [tf.cast(x, tf.float32) for x in record_as_list]
|
||||
tf.nest.map_structure(lambda x: tf.debugging.assert_type(x, tf.int32),
|
||||
record_as_list)
|
||||
dependencies = [
|
||||
tf.debugging.assert_less_equal(
|
||||
tf.reduce_sum([tf.norm(x, ord=1) for x in record_as_list]),
|
||||
params.l1_norm_bound,
|
||||
message=f'L1 norm exceeds {params.l1_norm_bound}.'),
|
||||
tf.debugging.assert_less_equal(
|
||||
tf.linalg.global_norm(record_as_float),
|
||||
params.l2_norm_bound,
|
||||
message=f'Global L2 norm exceeds {params.l2_norm_bound}.')
|
||||
]
|
||||
with tf.control_dependencies(dependencies):
|
||||
record = tf.cond(
|
||||
tf.equal(params.local_stddev, 0), lambda: record,
|
||||
lambda: self.add_noise_to_sample(params.local_stddev, record))
|
||||
return record
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""The noise was already added locally, therefore just continue."""
|
||||
event = dp_event.UnsupportedDpEvent()
|
||||
return sample_state, global_state, event
|
||||
|
||||
|
||||
class DistributedSkellamAverageQuery(normalized_query.NormalizedQuery):
|
||||
"""Implements DPQuery interface for Skellam average queries.
|
||||
|
||||
Checks norm bounds and adds Skellam noise to each vector, sums them up, casts
|
||||
to float32 and normalizes using the truediv operation.
|
||||
"""
|
||||
|
||||
def __init__(self, l1_norm_bound, l2_norm_bound, local_stddev, denominator):
|
||||
"""Initializes the GaussianAverageQuery.
|
||||
|
||||
Args:
|
||||
l1_norm_bound: The l1 norm bound to verify for each record.
|
||||
l2_norm_bound: The l2 norm bound to verify for each record.
|
||||
local_stddev: The local_stddev of the noise added to each record (before
|
||||
sum and normalization).
|
||||
denominator: The normalization constant (applied after sum).
|
||||
"""
|
||||
super().__init__(
|
||||
numerator_query=DistributedSkellamSumQuery(l1_norm_bound, l2_norm_bound,
|
||||
local_stddev),
|
||||
denominator=denominator)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Normalize accumulated sum with floordiv."""
|
||||
noised_sum, new_sum_global_state, event = self._numerator.get_noised_result(
|
||||
sample_state, global_state.numerator_state)
|
||||
|
||||
def normalize(v):
|
||||
return tf.math.truediv(tf.cast(v, tf.float32), global_state.denominator)
|
||||
|
||||
return (tf.nest.map_structure(normalize, noised_sum),
|
||||
self._GlobalState(new_sum_global_state,
|
||||
global_state.denominator), event)
|
|
@ -0,0 +1,168 @@
|
|||
# Copyright 2022, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import distributed_skellam_query
|
||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
|
||||
class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_skellam_sum_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([2, 0], dtype=tf.int32)
|
||||
record2 = tf.constant([-1, 1], dtype=tf.int32)
|
||||
|
||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1, 1]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_skellam_multiple_shapes(self):
|
||||
with self.cached_session() as sess:
|
||||
tensor1 = tf.constant([2, 0], dtype=tf.int32)
|
||||
tensor2 = tf.constant([-1, 1, 3], dtype=tf.int32)
|
||||
record = [tensor1, tensor2]
|
||||
|
||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record, record])
|
||||
result = sess.run(query_result)
|
||||
expected = [2 * tensor1, 2 * tensor2]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_skellam_raise_type_exception(self):
|
||||
with self.cached_session() as sess, self.assertRaises(TypeError):
|
||||
record1 = tf.constant([2, 0], dtype=tf.float32)
|
||||
record2 = tf.constant([-1, 1], dtype=tf.float32)
|
||||
|
||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
sess.run(query_result)
|
||||
|
||||
def test_skellam_raise_l1_norm_exception(self):
|
||||
with self.cached_session() as sess, self.assertRaises(
|
||||
tf.errors.InvalidArgumentError):
|
||||
record1 = tf.constant([1, 2], dtype=tf.int32)
|
||||
record2 = tf.constant([3, 4], dtype=tf.int32)
|
||||
|
||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||
l1_norm_bound=1, l2_norm_bound=100, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
sess.run(query_result)
|
||||
|
||||
def test_skellam_raise_l2_norm_exception(self):
|
||||
with self.cached_session() as sess, self.assertRaises(
|
||||
tf.errors.InvalidArgumentError):
|
||||
record1 = tf.constant([1, 2], dtype=tf.int32)
|
||||
record2 = tf.constant([3, 4], dtype=tf.int32)
|
||||
|
||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||
l1_norm_bound=10, l2_norm_bound=4, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
sess.run(query_result)
|
||||
|
||||
def test_skellam_sum_with_noise(self):
|
||||
"""Use only one record to test std."""
|
||||
with self.cached_session() as sess:
|
||||
record = tf.constant([1], dtype=tf.int32)
|
||||
local_stddev = 1.0
|
||||
|
||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||
l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
|
||||
query_result, _ = test_utils.run_query(query, [record])
|
||||
|
||||
noised_sums = []
|
||||
for _ in range(1000):
|
||||
noised_sums.append(sess.run(query_result))
|
||||
|
||||
result_stddev = np.std(noised_sums)
|
||||
self.assertNear(result_stddev, local_stddev, 0.1)
|
||||
|
||||
def test_compare_centralized_distributed_skellam(self):
|
||||
"""Compare the percentiles of distributed and centralized Skellam.
|
||||
|
||||
The test creates a large zero-vector with shape [num_trials, num_users] to
|
||||
be processed with the distributed Skellam noise stddev=1. The result is
|
||||
summed over the num_users dimension. The centralized result is produced by
|
||||
adding noise to a zero vector [num_trials] with stddev = 1*sqrt(num_users).
|
||||
Both results are evaluated to match percentiles (25, 50, 75).
|
||||
"""
|
||||
|
||||
with self.cached_session() as sess:
|
||||
num_trials = 10000
|
||||
num_users = 100
|
||||
record = tf.zeros([num_trials], dtype=tf.int32)
|
||||
local_stddev = 1.0
|
||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||
l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
|
||||
query_result, _ = test_utils.run_query(query, [record])
|
||||
distributed_noised = tf.zeros([num_trials], dtype=tf.int32)
|
||||
for _ in range(num_users):
|
||||
distributed_noised += sess.run(query_result)
|
||||
|
||||
def add_noise(v, stddev):
|
||||
lam = stddev**2 / 2
|
||||
|
||||
noise_poisson1 = tf.random.poisson(
|
||||
lam=lam, shape=tf.shape(v), dtype=v.dtype)
|
||||
noise_poisson2 = tf.random.poisson(
|
||||
lam=lam, shape=tf.shape(v), dtype=v.dtype)
|
||||
res = v + (noise_poisson1 - noise_poisson2)
|
||||
return res
|
||||
|
||||
record_centralized = tf.zeros([num_trials], dtype=tf.int32)
|
||||
centralized_noised = sess.run(
|
||||
add_noise(record_centralized, local_stddev * np.sqrt(num_users)))
|
||||
|
||||
tolerance = 5
|
||||
self.assertAllClose(
|
||||
tfp.stats.percentile(distributed_noised, 50.0),
|
||||
tfp.stats.percentile(centralized_noised, 50.0),
|
||||
atol=tolerance)
|
||||
self.assertAllClose(
|
||||
tfp.stats.percentile(distributed_noised, 75.0),
|
||||
tfp.stats.percentile(centralized_noised, 75.0),
|
||||
atol=tolerance)
|
||||
self.assertAllClose(
|
||||
tfp.stats.percentile(distributed_noised, 25.0),
|
||||
tfp.stats.percentile(centralized_noised, 25.0),
|
||||
atol=tolerance)
|
||||
|
||||
def test_skellam_average_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([1, 1], dtype=tf.int32)
|
||||
record2 = tf.constant([1, 1], dtype=tf.int32)
|
||||
|
||||
query = distributed_skellam_query.DistributedSkellamAverageQuery(
|
||||
l1_norm_bound=3.0,
|
||||
l2_norm_bound=3.0,
|
||||
local_stddev=0.0,
|
||||
denominator=2.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected_average = [1, 1]
|
||||
self.assertAllClose(result, expected_average)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue