From f3af24b00ebe9a598337a592ff9096299c23b6e6 Mon Sep 17 00:00:00 2001 From: Ken Liu Date: Sun, 8 Aug 2021 03:43:01 -0700 Subject: [PATCH] Adds central discrete Gaussian DPQuery. PiperOrigin-RevId: 389467360 --- tensorflow_privacy/__init__.py | 1 + .../dp_query/discrete_gaussian_query.py | 89 +++++++++++ .../dp_query/discrete_gaussian_query_test.py | 148 ++++++++++++++++++ .../distributed_discrete_gaussian_query.py | 6 +- ...istributed_discrete_gaussian_query_test.py | 2 +- 5 files changed, 242 insertions(+), 4 deletions(-) create mode 100644 tensorflow_privacy/privacy/dp_query/discrete_gaussian_query.py create mode 100644 tensorflow_privacy/privacy/dp_query/discrete_gaussian_query_test.py diff --git a/tensorflow_privacy/__init__.py b/tensorflow_privacy/__init__.py index f775d80..cfd5344 100644 --- a/tensorflow_privacy/__init__.py +++ b/tensorflow_privacy/__init__.py @@ -43,6 +43,7 @@ else: # DPQuery classes from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery + from tensorflow_privacy.privacy.dp_query.discrete_gaussian_query import DiscreteGaussianSumQuery from tensorflow_privacy.privacy.dp_query.distributed_discrete_gaussian_query import DistributedDiscreteGaussianSumQuery from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery diff --git a/tensorflow_privacy/privacy/dp_query/discrete_gaussian_query.py b/tensorflow_privacy/privacy/dp_query/discrete_gaussian_query.py new file mode 100644 index 0000000..444489b --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/discrete_gaussian_query.py @@ -0,0 +1,89 @@ +# Copyright 2021, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implements DPQuery interface for discrete Gaussian mechanism.""" + +import collections + +import tensorflow as tf +from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils +from tensorflow_privacy.privacy.dp_query import dp_query + + +class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery): + """Implements DPQuery for discrete Gaussian sum queries. + + For each local record, we check the L2 norm bound and add discrete Gaussian + noise. In particular, this DPQuery does not perform L2 norm clipping and the + norms of the input records are expected to be bounded. + """ + + # pylint: disable=invalid-name + _GlobalState = collections.namedtuple('_GlobalState', + ['l2_norm_bound', 'stddev']) + + # pylint: disable=invalid-name + _SampleParams = collections.namedtuple('_SampleParams', + ['l2_norm_bound', 'stddev']) + + def __init__(self, l2_norm_bound, stddev): + """Initializes the DiscreteGaussianSumQuery. + + Args: + l2_norm_bound: The L2 norm bound to verify for each record. + stddev: The stddev of the discrete Gaussian noise added to the sum. + """ + self._l2_norm_bound = l2_norm_bound + self._stddev = stddev + + def set_ledger(self, ledger): + del ledger # Unused. + raise NotImplementedError('Ledger has not yet been implemented for' + 'DiscreteGaussianSumQuery!') + + def initial_global_state(self): + return self._GlobalState( + tf.cast(self._l2_norm_bound, tf.float32), + tf.cast(self._stddev, tf.float32)) + + def derive_sample_params(self, global_state): + return self._SampleParams(global_state.l2_norm_bound, global_state.stddev) + + def preprocess_record(self, params, record): + """Check record norm and add noise to the record.""" + record_as_list = tf.nest.flatten(record) + record_as_float_list = [tf.cast(x, tf.float32) for x in record_as_list] + tf.nest.map_structure(lambda x: tf.compat.v1.assert_type(x, tf.int32), + record_as_list) + dependencies = [ + tf.compat.v1.assert_less_equal( + tf.linalg.global_norm(record_as_float_list), + params.l2_norm_bound, + message=f'Global L2 norm exceeds {params.l2_norm_bound}.') + ] + with tf.control_dependencies(dependencies): + return tf.nest.map_structure(tf.identity, record) + + def get_noised_result(self, sample_state, global_state): + """Adds discrete Gaussian noise to the aggregate.""" + # Round up the noise as the TF discrete Gaussian sampler only takes + # integer noise stddevs for now. + ceil_stddev = tf.cast(tf.math.ceil(global_state.stddev), tf.int32) + + def add_noise(v): + noised_v = v + discrete_gaussian_utils.sample_discrete_gaussian( + scale=ceil_stddev, shape=tf.shape(v), dtype=v.dtype) + # Ensure shape as TF shape inference may fail due to custom noise sampler. + return tf.ensure_shape(noised_v, v.shape) + + return tf.nest.map_structure(add_noise, sample_state), global_state diff --git a/tensorflow_privacy/privacy/dp_query/discrete_gaussian_query_test.py b/tensorflow_privacy/privacy/dp_query/discrete_gaussian_query_test.py new file mode 100644 index 0000000..fc14e7c --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/discrete_gaussian_query_test.py @@ -0,0 +1,148 @@ +# Copyright 2021, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for DiscreteGaussianSumQuery.""" + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf +from tensorflow_privacy.privacy.dp_query import discrete_gaussian_query +from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils +from tensorflow_privacy.privacy.dp_query import test_utils + +dg_sum_query = discrete_gaussian_query.DiscreteGaussianSumQuery + + +def silence_tf_error_messages(func): + """Decorator that temporarily changes the TF logging levels.""" + + def wrapper(*args, **kwargs): + cur_verbosity = tf.compat.v1.logging.get_verbosity() + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL) + func(*args, **kwargs) + tf.compat.v1.logging.set_verbosity(cur_verbosity) # Reset verbosity. + + return wrapper + + +class DiscreteGaussianQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_sum_no_noise(self): + with self.cached_session() as sess: + record1 = tf.constant([2, 0], dtype=tf.int32) + record2 = tf.constant([-1, 1], dtype=tf.int32) + + query = dg_sum_query(l2_norm_bound=10, stddev=0.0) + query_result, _ = test_utils.run_query(query, [record1, record2]) + result = sess.run(query_result) + expected = [1, 1] + self.assertAllEqual(result, expected) + + @parameterized.product(sample_size=[1, 3]) + def test_sum_multiple_shapes(self, sample_size): + with self.cached_session() as sess: + t1 = tf.constant([2, 0], dtype=tf.int32) + t2 = tf.constant([-1, 1, 3], dtype=tf.int32) + t3 = tf.constant([-2], dtype=tf.int32) + record = [t1, t2, t3] + sample = [record] * sample_size + + query = dg_sum_query(l2_norm_bound=10, stddev=0.0) + query_result, _ = test_utils.run_query(query, sample) + expected = [sample_size * t1, sample_size * t2, sample_size * t3] + result, expected = sess.run([query_result, expected]) + # Use `assertAllClose` for nested structures equality (with tolerance=0). + self.assertAllClose(result, expected, atol=0) + + @parameterized.product(sample_size=[1, 3]) + def test_sum_nested_record_structure(self, sample_size): + with self.cached_session() as sess: + t1 = tf.constant([1, 0], dtype=tf.int32) + t2 = tf.constant([1, 1, 1], dtype=tf.int32) + t3 = tf.constant([1], dtype=tf.int32) + t4 = tf.constant([[1, 1], [1, 1]], dtype=tf.int32) + record = [t1, dict(a=t2, b=[t3, (t4, t1)])] + sample = [record] * sample_size + + query = dg_sum_query(l2_norm_bound=10, stddev=0.0) + query_result, _ = test_utils.run_query(query, sample) + result = sess.run(query_result) + + s = sample_size + expected = [t1 * s, dict(a=t2 * s, b=[t3 * s, (t4 * s, t1 * s)])] + # Use `assertAllClose` for nested structures equality (with tolerance=0) + self.assertAllClose(result, expected, atol=0) + + def test_sum_raise_on_float_inputs(self): + with self.cached_session() as sess: + record1 = tf.constant([2, 0], dtype=tf.float32) + record2 = tf.constant([-1, 1], dtype=tf.float32) + query = dg_sum_query(l2_norm_bound=10, stddev=0.0) + + with self.assertRaises(TypeError): + query_result, _ = test_utils.run_query(query, [record1, record2]) + sess.run(query_result) + + @parameterized.product(l2_norm_bound=[0, 3, 10, 14.1]) + @silence_tf_error_messages + def test_sum_raise_on_l2_norm_excess(self, l2_norm_bound): + with self.cached_session() as sess: + record = tf.constant([10, 10], dtype=tf.int32) + query = dg_sum_query(l2_norm_bound=l2_norm_bound, stddev=0.0) + + with self.assertRaises(tf.errors.InvalidArgumentError): + query_result, _ = test_utils.run_query(query, [record]) + sess.run(query_result) + + def test_sum_float_norm_not_rounded(self): + """Test that the float L2 norm bound doesn't get rounded/casted to integers.""" + with self.cached_session() as sess: + # A casted/rounded norm bound would be insufficient. + l2_norm_bound = 14.2 + record = tf.constant([10, 10], dtype=tf.int32) + query = dg_sum_query(l2_norm_bound=l2_norm_bound, stddev=0.0) + query_result, _ = test_utils.run_query(query, [record]) + result = sess.run(query_result) + expected = [10, 10] + self.assertAllEqual(result, expected) + + @parameterized.product(stddev=[10, 100, 1000]) + def test_noisy_sum(self, stddev): + num_trials = 1000 + record_1 = tf.zeros([num_trials], dtype=tf.int32) + record_2 = tf.ones([num_trials], dtype=tf.int32) + sample = [record_1, record_2] + query = dg_sum_query(l2_norm_bound=num_trials, stddev=stddev) + result, _ = test_utils.run_query(query, sample) + + sampled_noise = discrete_gaussian_utils.sample_discrete_gaussian( + scale=tf.cast(stddev, tf.int32), shape=[num_trials], dtype=tf.int32) + + result, sampled_noise = self.evaluate([result, sampled_noise]) + + # The standard error of the stddev should be roughly sigma / sqrt(2N - 2), + # (https://stats.stackexchange.com/questions/156518) so set a rtol to give + # < 0.01% of failure (within ~4 standard errors). + rtol = 4 / np.sqrt(2 * num_trials - 2) + self.assertAllClose(np.std(result), stddev, rtol=rtol) + + # Use standard error of the mean to compare percentiles. + stderr = stddev / np.sqrt(num_trials) + self.assertAllClose( + np.percentile(result, [25, 50, 75]), + np.percentile(sampled_noise, [25, 50, 75]), + atol=4 * stderr) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query.py b/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query.py index 5b450ee..8dd4dba 100644 --- a/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query.py +++ b/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query.py @@ -41,7 +41,7 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery): Args: l2_norm_bound: The L2 norm bound to verify for each record. - local_stddev: The scale/stddev of the local discrete Gaussian noise. + local_stddev: The stddev of the local discrete Gaussian noise. """ self._l2_norm_bound = l2_norm_bound self._local_stddev = local_stddev @@ -65,7 +65,7 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery): Args: record: The record to which we generate and add local noise. - local_stddev: The scale/stddev of the local discrete Gaussian noise. + local_stddev: The stddev of the local discrete Gaussian noise. shares: Number of shares of local noise to generate. Should be 1 for each record. This can be useful when we want to generate multiple noise shares at once. @@ -84,7 +84,7 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery): scale=ceil_local_stddev, shape=shape, dtype=v.dtype) # Sum across the number of noise shares and add it. noised_v = v + tf.reduce_sum(dgauss_noise, axis=0) - # Ensure shape as TF shape inference may fail due to custom noise sampler. + # Set shape as TF shape inference may fail due to custom noise sampler. noised_v.set_shape(v.shape.as_list()) return noised_v diff --git a/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query_test.py b/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query_test.py index b2f6051..1c1a461 100644 --- a/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/distributed_discrete_gaussian_query_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for DistributedDiscreteGaussianQuery.""" +"""Tests for DistributedDiscreteGaussianSumQuery.""" from absl.testing import parameterized import numpy as np