From 89de03e0dbc5ebf32835ca00a2426ea607bf6516 Mon Sep 17 00:00:00 2001 From: Peter Kairouz Date: Fri, 18 Feb 2022 15:47:53 -0800 Subject: [PATCH] Adds `DistributedSkellamQuery` to public TF Privacy. PiperOrigin-RevId: 429664212 --- tensorflow_privacy/privacy/dp_query/BUILD | 22 +++ .../dp_query/distributed_skellam_query.py | 165 +++++++++++++++++ .../distributed_skellam_query_test.py | 168 ++++++++++++++++++ 3 files changed, 355 insertions(+) create mode 100644 tensorflow_privacy/privacy/dp_query/distributed_skellam_query.py create mode 100644 tensorflow_privacy/privacy/dp_query/distributed_skellam_query_test.py diff --git a/tensorflow_privacy/privacy/dp_query/BUILD b/tensorflow_privacy/privacy/dp_query/BUILD index a1bb9e1..b742f0c 100644 --- a/tensorflow_privacy/privacy/dp_query/BUILD +++ b/tensorflow_privacy/privacy/dp_query/BUILD @@ -75,6 +75,28 @@ py_test( ], ) +py_library( + name = "distributed_skellam_query", + srcs = ["distributed_skellam_query.py"], + srcs_version = "PY3", + deps = [ + ":dp_query", + ":normalized_query", + "//tensorflow_privacy/privacy/analysis:dp_event", + ], +) + +py_test( + name = "distributed_skellam_query_test", + srcs = ["distributed_skellam_query_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":distributed_skellam_query", + ":test_utils", + ], +) + py_library( name = "gaussian_query", srcs = ["gaussian_query.py"], diff --git a/tensorflow_privacy/privacy/dp_query/distributed_skellam_query.py b/tensorflow_privacy/privacy/dp_query/distributed_skellam_query.py new file mode 100644 index 0000000..4a4266a --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/distributed_skellam_query.py @@ -0,0 +1,165 @@ +# Copyright 2022, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for Skellam average queries."""

import collections

import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import normalized_query


class DistributedSkellamSumQuery(dp_query.SumAggregationDPQuery):
  """Implements DPQuery interface for discrete distributed sum queries.

  This implementation is for the distributed queries where the Skellam noise
  is applied locally to a discrete vector that matches the norm bound.
  """

  # pylint: disable=invalid-name
  _GlobalState = collections.namedtuple(
      '_GlobalState', ['l1_norm_bound', 'l2_norm_bound', 'local_stddev'])

  # pylint: disable=invalid-name
  _SampleParams = collections.namedtuple(
      '_SampleParams', ['l1_norm_bound', 'l2_norm_bound', 'local_stddev'])

  def __init__(self, l1_norm_bound, l2_norm_bound, local_stddev):
    """Initializes the DistributedSkellamSumQuery.

    Args:
      l1_norm_bound: The l1 norm bound to verify for each record.
      l2_norm_bound: The l2 norm bound to verify for each record.
      local_stddev: The standard deviation of the Skellam distribution.
    """
    self._l1_norm_bound = l1_norm_bound
    self._l2_norm_bound = l2_norm_bound
    self._local_stddev = local_stddev

  def initial_global_state(self):
    """Since we operate on discrete values, use int for L1 bound and float for L2 bound."""
    return self._GlobalState(
        tf.cast(self._l1_norm_bound, tf.int32),
        tf.cast(self._l2_norm_bound, tf.float32),
        tf.cast(self._local_stddev, tf.float32))

  def derive_sample_params(self, global_state):
    """Packs the norm bounds and local stddev into per-sample parameters."""
    return self._SampleParams(global_state.l1_norm_bound,
                              global_state.l2_norm_bound,
                              global_state.local_stddev)

  def add_noise_to_sample(self, local_stddev, record):
    """Adds Skellam noise to the sample.

    We use difference of two Poisson random variable with lambda hyperparameter
    that equals 'local_stddev**2/2' that results in a standard deviation
    'local_stddev' for the Skellam noise to be added locally.

    Args:
      local_stddev: The standard deviation of the local Skellam noise.
      record: The record to be processed.

    Returns:
      A record with added noise.
    """
    # Use float64 as the stddev could be large after quantization.
    local_stddev = tf.cast(local_stddev, tf.float64)
    poisson_lam = 0.5 * local_stddev * local_stddev

    def add_noise(v):
      # NOTE(review): the stateless seed is derived from tf.timestamp(), so
      # draws are not reproducible across runs — presumably intentional for
      # per-round fresh noise; confirm before relying on determinism.
      poissons = tf.random.stateless_poisson(
          shape=tf.concat([tf.shape(v), [2]], axis=0),  # Two draws of Poisson.
          seed=tf.cast([tf.timestamp() * 10**6, 0], tf.int64),
          lam=[poisson_lam, poisson_lam],
          dtype=tf.int64)
      return v + tf.cast(poissons[..., 0] - poissons[..., 1], v.dtype)

    return tf.nest.map_structure(add_noise, record)

  def preprocess_record(self, params, record):
    """Check record norm and add noise to the record.

    For both L1 and L2 norms we compute a global norm of the provided record.
    Since the record contains int32 tensors we cast them into float32 to
    compute L2 norm. In the end we run three asserts: type, l1, and l2 norms.

    Args:
      params: The parameters for the particular sample.
      record: The record to be processed.

    Returns:
      The record with Skellam noise added to each tensor; noise is skipped
      entirely when `params.local_stddev` is 0.
    """
    record_as_list = tf.nest.flatten(record)
    record_as_float = [tf.cast(x, tf.float32) for x in record_as_list]
    # Every leaf tensor must be int32 (discrete query).
    tf.nest.map_structure(lambda x: tf.debugging.assert_type(x, tf.int32),
                          record_as_list)
    dependencies = [
        tf.debugging.assert_less_equal(
            tf.reduce_sum([tf.norm(x, ord=1) for x in record_as_list]),
            params.l1_norm_bound,
            message=f'L1 norm exceeds {params.l1_norm_bound}.'),
        tf.debugging.assert_less_equal(
            tf.linalg.global_norm(record_as_float),
            params.l2_norm_bound,
            message=f'Global L2 norm exceeds {params.l2_norm_bound}.')
    ]
    with tf.control_dependencies(dependencies):
      record = tf.cond(
          tf.equal(params.local_stddev, 0), lambda: record,
          lambda: self.add_noise_to_sample(params.local_stddev, record))
      return record

  def get_noised_result(self, sample_state, global_state):
    """The noise was already added locally, therefore just continue."""
    event = dp_event.UnsupportedDpEvent()
    return sample_state, global_state, event


class DistributedSkellamAverageQuery(normalized_query.NormalizedQuery):
  """Implements DPQuery interface for Skellam average queries.

  Checks norm bounds and adds Skellam noise to each vector, sums them up, casts
  to float32 and normalizes using the truediv operation.
  """

  def __init__(self, l1_norm_bound, l2_norm_bound, local_stddev, denominator):
    """Initializes the DistributedSkellamAverageQuery.

    Args:
      l1_norm_bound: The l1 norm bound to verify for each record.
      l2_norm_bound: The l2 norm bound to verify for each record.
      local_stddev: The local_stddev of the noise added to each record (before
        sum and normalization).
      denominator: The normalization constant (applied after sum).
    """
    super().__init__(
        numerator_query=DistributedSkellamSumQuery(l1_norm_bound, l2_norm_bound,
                                                   local_stddev),
        denominator=denominator)

  def get_noised_result(self, sample_state, global_state):
    """Normalizes the accumulated sum with truediv (float division)."""
    noised_sum, new_sum_global_state, event = self._numerator.get_noised_result(
        sample_state, global_state.numerator_state)

    def normalize(v):
      return tf.math.truediv(tf.cast(v, tf.float32), global_state.denominator)

    return (tf.nest.map_structure(normalize, noised_sum),
            self._GlobalState(new_sum_global_state,
                              global_state.denominator), event)

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_skellam_query
from tensorflow_privacy.privacy.dp_query import test_utils
import tensorflow_probability as tfp


class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):

  def test_skellam_sum_no_noise(self):
    """With zero stddev the query is an exact elementwise integer sum."""
    with self.cached_session() as sess:
      first_record = tf.constant([2, 0], dtype=tf.int32)
      second_record = tf.constant([-1, 1], dtype=tf.int32)

      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
      query_result, _ = test_utils.run_query(
          query, [first_record, second_record])
      self.assertAllClose(sess.run(query_result), [1, 1])

  def test_skellam_multiple_shapes(self):
    """A record may be a structure of tensors with differing shapes."""
    with self.cached_session() as sess:
      part_a = tf.constant([2, 0], dtype=tf.int32)
      part_b = tf.constant([-1, 1, 3], dtype=tf.int32)
      record = [part_a, part_b]

      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
      query_result, _ = test_utils.run_query(query, [record, record])
      self.assertAllClose(sess.run(query_result), [2 * part_a, 2 * part_b])

  def test_skellam_raise_type_exception(self):
    """Float records are rejected: the query accepts only int32 tensors."""
    with self.cached_session() as sess, self.assertRaises(TypeError):
      float_record1 = tf.constant([2, 0], dtype=tf.float32)
      float_record2 = tf.constant([-1, 1], dtype=tf.float32)

      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
      query_result, _ = test_utils.run_query(
          query, [float_record1, float_record2])
      sess.run(query_result)

  def test_skellam_raise_l1_norm_exception(self):
    """Records whose L1 norm exceeds the bound trigger a runtime assert."""
    with self.cached_session() as sess, self.assertRaises(
        tf.errors.InvalidArgumentError):
      big_record1 = tf.constant([1, 2], dtype=tf.int32)
      big_record2 = tf.constant([3, 4], dtype=tf.int32)

      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=1, l2_norm_bound=100, local_stddev=0.0)
      query_result, _ = test_utils.run_query(query, [big_record1, big_record2])

      sess.run(query_result)

  def test_skellam_raise_l2_norm_exception(self):
    """Records whose global L2 norm exceeds the bound trigger an assert."""
    with self.cached_session() as sess, self.assertRaises(
        tf.errors.InvalidArgumentError):
      big_record1 = tf.constant([1, 2], dtype=tf.int32)
      big_record2 = tf.constant([3, 4], dtype=tf.int32)

      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10, l2_norm_bound=4, local_stddev=0.0)
      query_result, _ = test_utils.run_query(query, [big_record1, big_record2])

      sess.run(query_result)

  def test_skellam_sum_with_noise(self):
    """Use only one record to test std."""
    with self.cached_session() as sess:
      record = tf.constant([1], dtype=tf.int32)
      local_stddev = 1.0

      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
      query_result, _ = test_utils.run_query(query, [record])

      # Empirical stddev of 1000 draws should match the configured stddev.
      noised_sums = [sess.run(query_result) for _ in range(1000)]
      self.assertNear(np.std(noised_sums), local_stddev, 0.1)

  def test_compare_centralized_distributed_skellam(self):
    """Compare the percentiles of distributed and centralized Skellam.

    The test creates a large zero-vector with shape [num_trials, num_users] to
    be processed with the distributed Skellam noise stddev=1. The result is
    summed over the num_users dimension. The centralized result is produced by
    adding noise to a zero vector [num_trials] with stddev = 1*sqrt(num_users).
    Both results are evaluated to match percentiles (25, 50, 75).
    """

    with self.cached_session() as sess:
      num_trials = 10000
      num_users = 100
      record = tf.zeros([num_trials], dtype=tf.int32)
      local_stddev = 1.0
      query = distributed_skellam_query.DistributedSkellamSumQuery(
          l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
      query_result, _ = test_utils.run_query(query, [record])
      distributed_noised = tf.zeros([num_trials], dtype=tf.int32)
      for _ in range(num_users):
        distributed_noised += sess.run(query_result)

      def add_noise(v, stddev):
        # Skellam = difference of two Poissons with lam = stddev**2 / 2.
        lam = stddev**2 / 2

        first_draw = tf.random.poisson(
            lam=lam, shape=tf.shape(v), dtype=v.dtype)
        second_draw = tf.random.poisson(
            lam=lam, shape=tf.shape(v), dtype=v.dtype)
        return v + (first_draw - second_draw)

      record_centralized = tf.zeros([num_trials], dtype=tf.int32)
      centralized_noised = sess.run(
          add_noise(record_centralized, local_stddev * np.sqrt(num_users)))

      tolerance = 5
      for percentile in [25.0, 50.0, 75.0]:
        self.assertAllClose(
            tfp.stats.percentile(distributed_noised, percentile),
            tfp.stats.percentile(centralized_noised, percentile),
            atol=tolerance)

  def test_skellam_average_no_noise(self):
    """With zero stddev the average query yields the exact mean."""
    with self.cached_session() as sess:
      record1 = tf.constant([1, 1], dtype=tf.int32)
      record2 = tf.constant([1, 1], dtype=tf.int32)

      query = distributed_skellam_query.DistributedSkellamAverageQuery(
          l1_norm_bound=3.0,
          l2_norm_bound=3.0,
          local_stddev=0.0,
          denominator=2.0)
      query_result, _ = test_utils.run_query(query, [record1, record2])
      self.assertAllClose(sess.run(query_result), [1, 1])


if __name__ == '__main__':
  tf.test.main()