Adds central discrete Gaussian DPQuery.

PiperOrigin-RevId: 389467360
This commit is contained in:
Ken Liu 2021-08-08 03:43:01 -07:00 committed by A. Unique TensorFlower
parent aa3f841893
commit f3af24b00e
5 changed files with 242 additions and 4 deletions

View file

@ -43,6 +43,7 @@ else:
# DPQuery classes # DPQuery classes
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery
from tensorflow_privacy.privacy.dp_query.discrete_gaussian_query import DiscreteGaussianSumQuery
from tensorflow_privacy.privacy.dp_query.distributed_discrete_gaussian_query import DistributedDiscreteGaussianSumQuery from tensorflow_privacy.privacy.dp_query.distributed_discrete_gaussian_query import DistributedDiscreteGaussianSumQuery
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery

View file

@ -0,0 +1,89 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for discrete Gaussian mechanism."""
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query
class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery for discrete Gaussian sum queries.
For each local record, we check the L2 norm bound and add discrete Gaussian
noise. In particular, this DPQuery does not perform L2 norm clipping and the
norms of the input records are expected to be bounded.
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple('_GlobalState',
['l2_norm_bound', 'stddev'])
# pylint: disable=invalid-name
_SampleParams = collections.namedtuple('_SampleParams',
['l2_norm_bound', 'stddev'])
def __init__(self, l2_norm_bound, stddev):
"""Initializes the DiscreteGaussianSumQuery.
Args:
l2_norm_bound: The L2 norm bound to verify for each record.
stddev: The stddev of the discrete Gaussian noise added to the sum.
"""
self._l2_norm_bound = l2_norm_bound
self._stddev = stddev
def set_ledger(self, ledger):
del ledger # Unused.
raise NotImplementedError('Ledger has not yet been implemented for'
'DiscreteGaussianSumQuery!')
def initial_global_state(self):
return self._GlobalState(
tf.cast(self._l2_norm_bound, tf.float32),
tf.cast(self._stddev, tf.float32))
def derive_sample_params(self, global_state):
return self._SampleParams(global_state.l2_norm_bound, global_state.stddev)
def preprocess_record(self, params, record):
"""Check record norm and add noise to the record."""
record_as_list = tf.nest.flatten(record)
record_as_float_list = [tf.cast(x, tf.float32) for x in record_as_list]
tf.nest.map_structure(lambda x: tf.compat.v1.assert_type(x, tf.int32),
record_as_list)
dependencies = [
tf.compat.v1.assert_less_equal(
tf.linalg.global_norm(record_as_float_list),
params.l2_norm_bound,
message=f'Global L2 norm exceeds {params.l2_norm_bound}.')
]
with tf.control_dependencies(dependencies):
return tf.nest.map_structure(tf.identity, record)
def get_noised_result(self, sample_state, global_state):
"""Adds discrete Gaussian noise to the aggregate."""
# Round up the noise as the TF discrete Gaussian sampler only takes
# integer noise stddevs for now.
ceil_stddev = tf.cast(tf.math.ceil(global_state.stddev), tf.int32)
def add_noise(v):
noised_v = v + discrete_gaussian_utils.sample_discrete_gaussian(
scale=ceil_stddev, shape=tf.shape(v), dtype=v.dtype)
# Ensure shape as TF shape inference may fail due to custom noise sampler.
return tf.ensure_shape(noised_v, v.shape)
return tf.nest.map_structure(add_noise, sample_state), global_state

View file

@ -0,0 +1,148 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for DiscreteGaussianSumQuery."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import test_utils
dg_sum_query = discrete_gaussian_query.DiscreteGaussianSumQuery
def silence_tf_error_messages(func):
"""Decorator that temporarily changes the TF logging levels."""
def wrapper(*args, **kwargs):
cur_verbosity = tf.compat.v1.logging.get_verbosity()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
func(*args, **kwargs)
tf.compat.v1.logging.set_verbosity(cur_verbosity) # Reset verbosity.
return wrapper
class DiscreteGaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_sum_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([2, 0], dtype=tf.int32)
record2 = tf.constant([-1, 1], dtype=tf.int32)
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1, 1]
self.assertAllEqual(result, expected)
@parameterized.product(sample_size=[1, 3])
def test_sum_multiple_shapes(self, sample_size):
with self.cached_session() as sess:
t1 = tf.constant([2, 0], dtype=tf.int32)
t2 = tf.constant([-1, 1, 3], dtype=tf.int32)
t3 = tf.constant([-2], dtype=tf.int32)
record = [t1, t2, t3]
sample = [record] * sample_size
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
query_result, _ = test_utils.run_query(query, sample)
expected = [sample_size * t1, sample_size * t2, sample_size * t3]
result, expected = sess.run([query_result, expected])
# Use `assertAllClose` for nested structures equality (with tolerance=0).
self.assertAllClose(result, expected, atol=0)
@parameterized.product(sample_size=[1, 3])
def test_sum_nested_record_structure(self, sample_size):
with self.cached_session() as sess:
t1 = tf.constant([1, 0], dtype=tf.int32)
t2 = tf.constant([1, 1, 1], dtype=tf.int32)
t3 = tf.constant([1], dtype=tf.int32)
t4 = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
record = [t1, dict(a=t2, b=[t3, (t4, t1)])]
sample = [record] * sample_size
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
query_result, _ = test_utils.run_query(query, sample)
result = sess.run(query_result)
s = sample_size
expected = [t1 * s, dict(a=t2 * s, b=[t3 * s, (t4 * s, t1 * s)])]
# Use `assertAllClose` for nested structures equality (with tolerance=0)
self.assertAllClose(result, expected, atol=0)
def test_sum_raise_on_float_inputs(self):
with self.cached_session() as sess:
record1 = tf.constant([2, 0], dtype=tf.float32)
record2 = tf.constant([-1, 1], dtype=tf.float32)
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
with self.assertRaises(TypeError):
query_result, _ = test_utils.run_query(query, [record1, record2])
sess.run(query_result)
@parameterized.product(l2_norm_bound=[0, 3, 10, 14.1])
@silence_tf_error_messages
def test_sum_raise_on_l2_norm_excess(self, l2_norm_bound):
with self.cached_session() as sess:
record = tf.constant([10, 10], dtype=tf.int32)
query = dg_sum_query(l2_norm_bound=l2_norm_bound, stddev=0.0)
with self.assertRaises(tf.errors.InvalidArgumentError):
query_result, _ = test_utils.run_query(query, [record])
sess.run(query_result)
def test_sum_float_norm_not_rounded(self):
"""Test that the float L2 norm bound doesn't get rounded/casted to integers."""
with self.cached_session() as sess:
# A casted/rounded norm bound would be insufficient.
l2_norm_bound = 14.2
record = tf.constant([10, 10], dtype=tf.int32)
query = dg_sum_query(l2_norm_bound=l2_norm_bound, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record])
result = sess.run(query_result)
expected = [10, 10]
self.assertAllEqual(result, expected)
@parameterized.product(stddev=[10, 100, 1000])
def test_noisy_sum(self, stddev):
num_trials = 1000
record_1 = tf.zeros([num_trials], dtype=tf.int32)
record_2 = tf.ones([num_trials], dtype=tf.int32)
sample = [record_1, record_2]
query = dg_sum_query(l2_norm_bound=num_trials, stddev=stddev)
result, _ = test_utils.run_query(query, sample)
sampled_noise = discrete_gaussian_utils.sample_discrete_gaussian(
scale=tf.cast(stddev, tf.int32), shape=[num_trials], dtype=tf.int32)
result, sampled_noise = self.evaluate([result, sampled_noise])
# The standard error of the stddev should be roughly sigma / sqrt(2N - 2),
# (https://stats.stackexchange.com/questions/156518) so set a rtol to give
# < 0.01% of failure (within ~4 standard errors).
rtol = 4 / np.sqrt(2 * num_trials - 2)
self.assertAllClose(np.std(result), stddev, rtol=rtol)
# Use standard error of the mean to compare percentiles.
stderr = stddev / np.sqrt(num_trials)
self.assertAllClose(
np.percentile(result, [25, 50, 75]),
np.percentile(sampled_noise, [25, 50, 75]),
atol=4 * stderr)
if __name__ == '__main__':
tf.test.main()

View file

@ -41,7 +41,7 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
Args: Args:
l2_norm_bound: The L2 norm bound to verify for each record. l2_norm_bound: The L2 norm bound to verify for each record.
local_stddev: The scale/stddev of the local discrete Gaussian noise. local_stddev: The stddev of the local discrete Gaussian noise.
""" """
self._l2_norm_bound = l2_norm_bound self._l2_norm_bound = l2_norm_bound
self._local_stddev = local_stddev self._local_stddev = local_stddev
@ -65,7 +65,7 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
Args: Args:
record: The record to which we generate and add local noise. record: The record to which we generate and add local noise.
local_stddev: The scale/stddev of the local discrete Gaussian noise. local_stddev: The stddev of the local discrete Gaussian noise.
shares: Number of shares of local noise to generate. Should be 1 for each shares: Number of shares of local noise to generate. Should be 1 for each
record. This can be useful when we want to generate multiple noise record. This can be useful when we want to generate multiple noise
shares at once. shares at once.
@ -84,7 +84,7 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
scale=ceil_local_stddev, shape=shape, dtype=v.dtype) scale=ceil_local_stddev, shape=shape, dtype=v.dtype)
# Sum across the number of noise shares and add it. # Sum across the number of noise shares and add it.
noised_v = v + tf.reduce_sum(dgauss_noise, axis=0) noised_v = v + tf.reduce_sum(dgauss_noise, axis=0)
# Ensure shape as TF shape inference may fail due to custom noise sampler. # Set shape as TF shape inference may fail due to custom noise sampler.
noised_v.set_shape(v.shape.as_list()) noised_v.set_shape(v.shape.as_list())
return noised_v return noised_v

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for DistributedDiscreteGaussianQuery.""" """Tests for DistributedDiscreteGaussianSumQuery."""
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np