Adds DistributedSkellamQuery to public TF Privacy.

PiperOrigin-RevId: 429664212
This commit is contained in:
Peter Kairouz 2022-02-18 15:47:53 -08:00 committed by A. Unique TensorFlower
parent ffc29e1d82
commit 89de03e0db
3 changed files with 355 additions and 0 deletions

View file

@ -75,6 +75,28 @@ py_test(
],
)
# Library target exposing distributed_skellam_query.py. It depends on the
# sibling dp_query and normalized_query targets and on the dp_event target
# from the privacy/analysis package.
py_library(
name = "distributed_skellam_query",
srcs = ["distributed_skellam_query.py"],
srcs_version = "PY3",
deps = [
":dp_query",
":normalized_query",
"//tensorflow_privacy/privacy/analysis:dp_event",
],
)
# Test target for distributed_skellam_query_test.py. It depends on the
# library under test and on the sibling test_utils target.
py_test(
name = "distributed_skellam_query_test",
srcs = ["distributed_skellam_query_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":distributed_skellam_query",
":test_utils",
],
)
py_library(
name = "gaussian_query",
srcs = ["gaussian_query.py"],

View file

@ -0,0 +1,165 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for Skellam average queries."""
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import normalized_query
class DistributedSkellamSumQuery(dp_query.SumAggregationDPQuery):
  """DPQuery for discrete distributed sums with locally added Skellam noise.

  Every record is verified against an L1 and an L2 norm bound and, unless the
  configured standard deviation equals zero, perturbed with Skellam noise
  during preprocessing. `get_noised_result` therefore adds no further noise.
  """

  # pylint: disable=invalid-name
  _GlobalState = collections.namedtuple(
      '_GlobalState', ['l1_norm_bound', 'l2_norm_bound', 'local_stddev'])

  # pylint: disable=invalid-name
  _SampleParams = collections.namedtuple(
      '_SampleParams', ['l1_norm_bound', 'l2_norm_bound', 'local_stddev'])

  def __init__(self, l1_norm_bound, l2_norm_bound, local_stddev):
    """Stores the configuration of the query.

    Args:
      l1_norm_bound: The l1 norm bound to verify for each record.
      l2_norm_bound: The l2 norm bound to verify for each record.
      local_stddev: The standard deviation of the Skellam distribution.
    """
    self._l1_norm_bound = l1_norm_bound
    self._l2_norm_bound = l2_norm_bound
    self._local_stddev = local_stddev

  def initial_global_state(self):
    """Returns the initial `_GlobalState`.

    Records are discrete, so the L1 bound is cast to int32; the L2 bound and
    the standard deviation are cast to float32.
    """
    l1_bound = tf.cast(self._l1_norm_bound, tf.int32)
    l2_bound = tf.cast(self._l2_norm_bound, tf.float32)
    stddev = tf.cast(self._local_stddev, tf.float32)
    return self._GlobalState(
        l1_norm_bound=l1_bound, l2_norm_bound=l2_bound, local_stddev=stddev)

  def derive_sample_params(self, global_state):
    """Copies the three global-state fields into a `_SampleParams` tuple."""
    return self._SampleParams(
        l1_norm_bound=global_state.l1_norm_bound,
        l2_norm_bound=global_state.l2_norm_bound,
        local_stddev=global_state.local_stddev)

  def add_noise_to_sample(self, local_stddev, record):
    """Adds Skellam noise to every tensor of `record`.

    The noise is the difference of two Poisson draws that share the rate
    `local_stddev**2 / 2`, which gives the difference a standard deviation of
    `local_stddev`.

    Args:
      local_stddev: The standard deviation of the local Skellam noise.
      record: The (possibly nested) record to be processed.

    Returns:
      A structure like `record` whose tensors carry the added noise.
    """
    # Use float64 as the stddev could be large after quantization.
    stddev64 = tf.cast(local_stddev, tf.float64)
    poisson_rate = 0.5 * stddev64 * stddev64

    def perturb(tensor):
      # A trailing axis of size two holds the two Poisson draws per element.
      draw_shape = tf.concat([tf.shape(tensor), [2]], axis=0)
      # The stateless sampler is seeded from the current timestamp.
      seed = tf.cast([tf.timestamp() * 10**6, 0], tf.int64)
      draws = tf.random.stateless_poisson(
          shape=draw_shape,
          seed=seed,
          lam=[poisson_rate, poisson_rate],
          dtype=tf.int64)
      skellam = tf.cast(draws[..., 0] - draws[..., 1], tensor.dtype)
      return tensor + skellam

    return tf.nest.map_structure(perturb, record)

  def preprocess_record(self, params, record):
    """Checks the record's type and norms, then adds the local noise.

    Both norms are computed globally over all tensors of the record: the L1
    norm is the sum of the per-tensor L1 norms, and the L2 norm is the global
    norm of the tensors cast to float32. Three checks run: dtype (int32), L1
    bound and L2 bound.

    Args:
      params: The `_SampleParams` for the particular sample.
      record: The record to be processed.

    Returns:
      The record itself when `params.local_stddev` equals zero, otherwise the
      record with Skellam noise added.
    """
    flat_tensors = tf.nest.flatten(record)
    float_tensors = [tf.cast(t, tf.float32) for t in flat_tensors]

    for tensor in flat_tensors:
      tf.debugging.assert_type(tensor, tf.int32)

    l1_norm = tf.reduce_sum([tf.norm(t, ord=1) for t in flat_tensors])
    l1_check = tf.debugging.assert_less_equal(
        l1_norm,
        params.l1_norm_bound,
        message=f'L1 norm exceeds {params.l1_norm_bound}.')

    l2_norm = tf.linalg.global_norm(float_tensors)
    l2_check = tf.debugging.assert_less_equal(
        l2_norm,
        params.l2_norm_bound,
        message=f'Global L2 norm exceeds {params.l2_norm_bound}.')

    # The noise branch is built under the norm checks' control dependencies.
    with tf.control_dependencies([l1_check, l2_check]):
      noise_disabled = tf.equal(params.local_stddev, 0)
      processed = tf.cond(
          noise_disabled,
          lambda: record,
          lambda: self.add_noise_to_sample(params.local_stddev, record))
    return processed

  def get_noised_result(self, sample_state, global_state):
    """Passes the sample state through; the noise was already added locally."""
    return sample_state, global_state, dp_event.UnsupportedDpEvent()
class DistributedSkellamAverageQuery(normalized_query.NormalizedQuery):
  """Implements DPQuery interface for Skellam average queries.

  Wraps a `DistributedSkellamSumQuery` as the numerator query; the resulting
  sum is cast to float32 and divided by the denominator with `truediv`.
  """

  def __init__(self, l1_norm_bound, l2_norm_bound, local_stddev, denominator):
    """Initializes the DistributedSkellamAverageQuery.

    Args:
      l1_norm_bound: The l1 norm bound to verify for each record.
      l2_norm_bound: The l2 norm bound to verify for each record.
      local_stddev: The local_stddev of the noise added to each record (before
        sum and normalization).
      denominator: The normalization constant (applied after sum).
    """
    sum_query = DistributedSkellamSumQuery(l1_norm_bound, l2_norm_bound,
                                           local_stddev)
    super().__init__(numerator_query=sum_query, denominator=denominator)

  def get_noised_result(self, sample_state, global_state):
    """Normalizes the accumulated sum with truediv.

    Args:
      sample_state: The accumulated sample state, forwarded to the numerator
        query.
      global_state: The global state; its `numerator_state` and `denominator`
        fields are read.

    Returns:
      A tuple `(average, new_global_state, event)` where `average` is the
      numerator result cast to float32 and divided by the denominator.
    """
    denominator = global_state.denominator
    noised_sum, numerator_state, event = self._numerator.get_noised_result(
        sample_state, global_state.numerator_state)

    average = tf.nest.map_structure(
        lambda t: tf.math.truediv(tf.cast(t, tf.float32), denominator),
        noised_sum)
    new_global_state = self._GlobalState(numerator_state, denominator)
    return average, new_global_state, event

View file

@ -0,0 +1,168 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_skellam_query
from tensorflow_privacy.privacy.dp_query import test_utils
import tensorflow_probability as tfp
class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the distributed Skellam sum and average queries."""

  def test_skellam_sum_no_noise(self):
    # With local_stddev=0 the query returns the exact element-wise sum.
    with self.cached_session() as sess:
      records = [
          tf.constant([2, 0], dtype=tf.int32),
          tf.constant([-1, 1], dtype=tf.int32),
      ]
      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
      query_result, _ = test_utils.run_query(query, records)
      result = sess.run(query_result)
      self.assertAllClose(result, [1, 1])

  def test_skellam_multiple_shapes(self):
    # A record may be a list of tensors with different shapes.
    with self.cached_session() as sess:
      tensor1 = tf.constant([2, 0], dtype=tf.int32)
      tensor2 = tf.constant([-1, 1, 3], dtype=tf.int32)
      record = [tensor1, tensor2]
      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
      query_result, _ = test_utils.run_query(query, [record, record])
      result = sess.run(query_result)
      expected = [2 * tensor1, 2 * tensor2]
      self.assertAllClose(result, expected)

  def test_skellam_raise_type_exception(self):
    # float32 records must be rejected by the int32 type check.
    with self.cached_session() as sess:
      with self.assertRaises(TypeError):
        records = [
            tf.constant([2, 0], dtype=tf.float32),
            tf.constant([-1, 1], dtype=tf.float32),
        ]
        query = distributed_skellam_query.DistributedSkellamSumQuery(
            l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
        query_result, _ = test_utils.run_query(query, records)
        sess.run(query_result)

  def test_skellam_raise_l1_norm_exception(self):
    # Records whose L1 norm exceeds the bound of 1 must fail the assertion.
    with self.cached_session() as sess:
      with self.assertRaises(tf.errors.InvalidArgumentError):
        records = [
            tf.constant([1, 2], dtype=tf.int32),
            tf.constant([3, 4], dtype=tf.int32),
        ]
        query = distributed_skellam_query.DistributedSkellamSumQuery(
            l1_norm_bound=1, l2_norm_bound=100, local_stddev=0.0)
        query_result, _ = test_utils.run_query(query, records)
        sess.run(query_result)

  def test_skellam_raise_l2_norm_exception(self):
    # The record [3, 4] has L2 norm 5, above the bound of 4.
    with self.cached_session() as sess:
      with self.assertRaises(tf.errors.InvalidArgumentError):
        records = [
            tf.constant([1, 2], dtype=tf.int32),
            tf.constant([3, 4], dtype=tf.int32),
        ]
        query = distributed_skellam_query.DistributedSkellamSumQuery(
            l1_norm_bound=10, l2_norm_bound=4, local_stddev=0.0)
        query_result, _ = test_utils.run_query(query, records)
        sess.run(query_result)

  def test_skellam_sum_with_noise(self):
    """Checks the empirical stddev of the noise using a single record."""
    with self.cached_session() as sess:
      record = tf.constant([1], dtype=tf.int32)
      local_stddev = 1.0
      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
      query_result, _ = test_utils.run_query(query, [record])

      noised_sums = [sess.run(query_result) for _ in range(1000)]

      self.assertNear(np.std(noised_sums), local_stddev, 0.1)

  def test_compare_centralized_distributed_skellam(self):
    """Compares the percentiles of distributed and centralized Skellam.

    A zero vector of length num_trials is run through the query with
    stddev=1 once per user, and the num_users noised results are summed. The
    centralized result adds Skellam noise with stddev = 1*sqrt(num_users) to
    a zero vector of the same length. Both results must agree on the
    percentiles 50, 75 and 25 within the tolerance.
    """
    with self.cached_session() as sess:
      num_trials = 10000
      num_users = 100
      record = tf.zeros([num_trials], dtype=tf.int32)
      local_stddev = 1.0
      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
      query_result, _ = test_utils.run_query(query, [record])

      distributed_noised = tf.zeros([num_trials], dtype=tf.int32)
      for _ in range(num_users):
        distributed_noised += sess.run(query_result)

      def add_noise(v, stddev):
        # Skellam noise as the difference of two Poisson draws.
        lam = stddev**2 / 2
        positive = tf.random.poisson(
            lam=lam, shape=tf.shape(v), dtype=v.dtype)
        negative = tf.random.poisson(
            lam=lam, shape=tf.shape(v), dtype=v.dtype)
        return v + (positive - negative)

      record_centralized = tf.zeros([num_trials], dtype=tf.int32)
      centralized_noised = sess.run(
          add_noise(record_centralized, local_stddev * np.sqrt(num_users)))

      tolerance = 5
      for q in (50.0, 75.0, 25.0):
        self.assertAllClose(
            tfp.stats.percentile(distributed_noised, q),
            tfp.stats.percentile(centralized_noised, q),
            atol=tolerance)

  def test_skellam_average_no_noise(self):
    # Two identical records averaged with denominator 2 give the record back.
    with self.cached_session() as sess:
      records = [
          tf.constant([1, 1], dtype=tf.int32),
          tf.constant([1, 1], dtype=tf.int32),
      ]
      query = distributed_skellam_query.DistributedSkellamAverageQuery(
          l1_norm_bound=3.0,
          l2_norm_bound=3.0,
          local_stddev=0.0,
          denominator=2.0)
      query_result, _ = test_utils.run_query(query, records)
      result = sess.run(query_result)
      self.assertAllClose(result, [1, 1])
# Run the TensorFlow test runner when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()