Adds discrete Gaussian (sampler and distributed DPQuery) to public TF Privacy.
PiperOrigin-RevId: 387232449
This commit is contained in:
parent
2f862eba9b
commit
e7e11d14d9
5 changed files with 697 additions and 0 deletions
|
@ -43,6 +43,7 @@ else:
|
|||
# DPQuery classes
|
||||
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
|
||||
from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery
|
||||
from tensorflow_privacy.privacy.dp_query.distributed_discrete_gaussian_query import DistributedDiscreteGaussianSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery
|
||||
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
|
||||
|
|
142
tensorflow_privacy/privacy/dp_query/discrete_gaussian_utils.py
Normal file
142
tensorflow_privacy/privacy/dp_query/discrete_gaussian_utils.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Util functions for drawing discrete Gaussian samples.
|
||||
|
||||
The following functions implement a vectorized TF version of the sampling
|
||||
algorithm described in the paper:
|
||||
|
||||
The Discrete Gaussian for Differential Privacy
|
||||
https://arxiv.org/pdf/2004.00010.pdf
|
||||
|
||||
Note that the exact sampling implementation should use integer and fractional
|
||||
parameters only. Here, we relax this constraint a bit and use vectorized
|
||||
implementations of Bernoulli and discrete Laplace sampling that can take float
|
||||
parameters.
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tf_prob
|
||||
|
||||
|
||||
def _sample_discrete_laplace(t, shape):
|
||||
"""Sample from discrete Laplace with scale t.
|
||||
|
||||
This method is based on the observation that sampling from Z ~ Lap(t) is
|
||||
equivalent to sampling X, Y independently from Geo(1 - exp(-1/t)) and take
|
||||
Z = X - Y.
|
||||
|
||||
Note also that tensorflow_probability's geometric sampler is based on floating
|
||||
operations and may possibly be inexact.
|
||||
|
||||
Args:
|
||||
t: The scale of the discrete Laplace distribution.
|
||||
shape: The tensor shape of the tensors drawn.
|
||||
|
||||
Returns:
|
||||
A tensor of the specified shape filled with random values.
|
||||
"""
|
||||
geometric_probs = 1.0 - tf.exp(-1.0 / tf.cast(t, tf.float64))
|
||||
sampler = tf_prob.distributions.Geometric(probs=geometric_probs)
|
||||
return tf.cast(sampler.sample(shape) - sampler.sample(shape), tf.int64)
|
||||
|
||||
|
||||
def _sample_bernoulli(p):
|
||||
"""Sample from Bernoulli(p)."""
|
||||
return tf_prob.distributions.Bernoulli(probs=p, dtype=tf.int64).sample()
|
||||
|
||||
|
||||
def _check_input_args(scale, shape, dtype):
|
||||
"""Checks the input args to the discrete Gaussian sampler."""
|
||||
if tf.as_dtype(dtype) not in (tf.int32, tf.int64):
|
||||
raise ValueError(
|
||||
f'Only tf.int32 and tf.int64 are supported. Found dtype `{dtype}`.')
|
||||
|
||||
checks = [
|
||||
tf.compat.v1.assert_non_negative(scale),
|
||||
tf.compat.v1.assert_integer(scale)
|
||||
]
|
||||
with tf.control_dependencies(checks):
|
||||
return tf.identity(scale), shape, dtype
|
||||
|
||||
|
||||
def _int_square(value):
|
||||
"""Avoids the TF op `Square(T=...)` for ints as sampling can happen on clients."""
|
||||
return (value - 1) * (value + 1) + 1
|
||||
|
||||
|
||||
@tf.function
|
||||
def _sample_discrete_gaussian_helper(scale, shape, dtype):
|
||||
"""Draw samples from discrete Gaussian, assuming scale >= 0."""
|
||||
scale = tf.cast(scale, tf.int64)
|
||||
sq_scale = _int_square(scale)
|
||||
|
||||
# Scale for discrete Laplace. The sampling algorithm should be correct
|
||||
# for any discrete Laplace scale, and the original paper uses
|
||||
# `dlap_scale = floor(scale) + 1`. Here we use `dlap_scale = scale` (where
|
||||
# input `scale` is restricted to integers >= 1) to simplify the fraction
|
||||
# below. It turns out that for integer scales >= 1, `dlap_scale = scale` gives
|
||||
# a good minimum success rate of ~70%, allowing a small oversampling factor.
|
||||
dlap_scale = scale
|
||||
oversample_factor = 1.5
|
||||
|
||||
# Draw at least some samples in case we got unlucky with small input shape.
|
||||
min_n = 1000
|
||||
target_n = tf.reduce_prod(tf.cast(shape, tf.int64))
|
||||
oversample_n = oversample_factor * tf.cast(target_n, tf.float32)
|
||||
draw_n = tf.maximum(min_n, tf.cast(oversample_n, tf.int32))
|
||||
|
||||
accepted_n = tf.constant(0, dtype=target_n.dtype)
|
||||
result = tf.zeros((0,), dtype=tf.int64)
|
||||
|
||||
while accepted_n < target_n:
|
||||
# Since the number of samples could be different in every retry, we need to
|
||||
# manually specify the shape info for TF.
|
||||
tf.autograph.experimental.set_loop_options(
|
||||
shape_invariants=[(result, tf.TensorShape([None]))])
|
||||
# Draw samples.
|
||||
samples = _sample_discrete_laplace(dlap_scale, shape=(draw_n,))
|
||||
z_numer = _int_square(tf.abs(samples) - scale)
|
||||
z_denom = 2 * sq_scale
|
||||
bern_probs = tf.exp(-1.0 * tf.divide(z_numer, z_denom))
|
||||
accept = _sample_bernoulli(bern_probs)
|
||||
# Keep successful samples and increment counter.
|
||||
accepted_samples = samples[tf.equal(accept, 1)]
|
||||
accepted_n += tf.cast(tf.size(accepted_samples), accepted_n.dtype)
|
||||
result = tf.concat([result, accepted_samples], axis=0)
|
||||
# Reduce the number of draws for any retries.
|
||||
draw_n = tf.cast(target_n - accepted_n, tf.float32) * oversample_factor
|
||||
draw_n = tf.maximum(min_n, tf.cast(draw_n, tf.int32))
|
||||
|
||||
return tf.cast(tf.reshape(result[:target_n], shape), dtype)
|
||||
|
||||
|
||||
def sample_discrete_gaussian(scale, shape, dtype=tf.int32):
|
||||
"""Draws (possibly inexact) samples from the discrete Gaussian distribution.
|
||||
|
||||
We relax some integer constraints to use vectorized implementations of
|
||||
Bernoulli and discrete Laplace sampling. Integer operations are done in
|
||||
tf.int64 as TF does not have direct support for fractions.
|
||||
|
||||
Args:
|
||||
scale: The scale of the discrete Gaussian distribution.
|
||||
shape: The shape of the output tensor.
|
||||
dtype: The type of the output.
|
||||
|
||||
Returns:
|
||||
A tensor of the specified shape filled with random values.
|
||||
"""
|
||||
scale, shape, dtype = _check_input_args(scale, shape, dtype)
|
||||
return tf.cond(
|
||||
tf.equal(scale, 0), lambda: tf.zeros(shape, dtype),
|
||||
lambda: _sample_discrete_gaussian_helper(scale, shape, dtype))
|
|
@ -0,0 +1,275 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for discrete_gaussian_utils."""
|
||||
|
||||
import collections
|
||||
import fractions
|
||||
import math
|
||||
import random
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
||||
|
||||
EXACT_SAMPLER_SEED = 4242
|
||||
|
||||
|
||||
class DiscreteGaussianUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.product(dtype=[tf.bool, tf.float32, tf.float64])
|
||||
def test_raise_on_bad_dtype(self, dtype):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = discrete_gaussian_utils.sample_discrete_gaussian(1, (1,), dtype)
|
||||
|
||||
def test_raise_on_negative_scale(self):
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
_ = discrete_gaussian_utils.sample_discrete_gaussian(-10, (1,))
|
||||
|
||||
def test_raise_on_float_scale(self):
|
||||
with self.assertRaises(TypeError):
|
||||
_ = discrete_gaussian_utils.sample_discrete_gaussian(3.14, (1,))
|
||||
|
||||
@parameterized.product(shape=[(), (1,), (100,), (2, 2), (3, 3, 3),
|
||||
(4, 1, 1, 1)])
|
||||
def test_shapes(self, shape):
|
||||
samples = discrete_gaussian_utils.sample_discrete_gaussian(100, shape)
|
||||
samples = self.evaluate(samples)
|
||||
self.assertAllEqual(samples.shape, shape)
|
||||
|
||||
@parameterized.product(dtype=[tf.int32, tf.int64])
|
||||
def test_dtypes(self, dtype):
|
||||
samples = discrete_gaussian_utils.sample_discrete_gaussian(1, (10,), dtype)
|
||||
samples = self.evaluate(samples)
|
||||
# Convert output np dtypes to tf dtypes.
|
||||
self.assertEqual(tf.as_dtype(samples.dtype), dtype)
|
||||
|
||||
def test_zero_noise(self):
|
||||
scale = 0
|
||||
shape = (100,)
|
||||
dtype = tf.int32
|
||||
samples = discrete_gaussian_utils.sample_discrete_gaussian(
|
||||
scale, shape, dtype=dtype)
|
||||
samples = self.evaluate(samples)
|
||||
self.assertAllEqual(samples, tf.zeros(shape, dtype=dtype))
|
||||
|
||||
@parameterized.named_parameters([('small_scale_small_n', 10, 2000, 1, 2),
|
||||
('small_scale_large_n', 10, 5000, 1, 1),
|
||||
('large_scale_small_n', 50, 2000, 2, 5),
|
||||
('large_scale_large_n', 50, 5000, 2, 3)])
|
||||
def test_match_exact_sampler(self, scale, num_samples, mean_std_atol,
|
||||
percentile_atol):
|
||||
true_samples = exact_sampler(scale, num_samples)
|
||||
drawn_samples = discrete_gaussian_utils.sample_discrete_gaussian(
|
||||
scale=scale, shape=(num_samples,))
|
||||
drawn_samples = self.evaluate(drawn_samples)
|
||||
|
||||
# Check mean, std, and percentiles.
|
||||
self.assertAllClose(
|
||||
np.mean(true_samples), np.mean(drawn_samples), atol=mean_std_atol)
|
||||
self.assertAllClose(
|
||||
np.std(true_samples), np.std(drawn_samples), atol=mean_std_atol)
|
||||
self.assertAllClose(
|
||||
np.percentile(true_samples, [10, 30, 50, 70, 90]),
|
||||
np.percentile(drawn_samples, [10, 30, 50, 70, 90]),
|
||||
atol=percentile_atol)
|
||||
|
||||
@parameterized.named_parameters([('n_1000', 1000, 5e-2),
|
||||
('n_10000', 10000, 5e-3)])
|
||||
def test_kl_divergence(self, num_samples, kl_tolerance):
|
||||
"""Compute KL divergence betwen empirical & true distribution."""
|
||||
scale = 10
|
||||
sq_sigma = scale * scale
|
||||
drawn_samples = discrete_gaussian_utils.sample_discrete_gaussian(
|
||||
scale=scale, shape=(num_samples,))
|
||||
drawn_samples = self.evaluate(drawn_samples)
|
||||
value_counts = collections.Counter(drawn_samples)
|
||||
|
||||
kl = 0
|
||||
norm_const = dgauss_normalizing_constant(sq_sigma)
|
||||
|
||||
for value, count in value_counts.items():
|
||||
kl += count * (
|
||||
math.log(count * norm_const / num_samples) + value * value /
|
||||
(2.0 * sq_sigma))
|
||||
|
||||
kl /= num_samples
|
||||
self.assertLess(kl, kl_tolerance)
|
||||
|
||||
|
||||
def exact_sampler(scale, num_samples, seed=EXACT_SAMPLER_SEED):
|
||||
"""Implementation of the exact discrete gaussian distribution sampler.
|
||||
|
||||
Source: https://arxiv.org/pdf/2004.00010.pdf.
|
||||
|
||||
Args:
|
||||
scale: The scale of the discrete Gaussian.
|
||||
num_samples: The number of samples to generate.
|
||||
seed: The seed for the random number generator to reproduce samples.
|
||||
|
||||
Returns:
|
||||
A numpy array of discrete Gaussian samples.
|
||||
"""
|
||||
|
||||
def randrange(a, rng):
|
||||
return rng.randrange(a)
|
||||
|
||||
def bern_em1(rng):
|
||||
"""Sample from Bernoulli(exp(-1))."""
|
||||
k = 2
|
||||
while True:
|
||||
if randrange(k, rng) == 0: # if Bernoulli(1/k)==1
|
||||
k = k + 1
|
||||
else:
|
||||
return k % 2
|
||||
|
||||
def bern_emab1(a, b, rng):
|
||||
"""Sample from Bernoulli(exp(-a/b)), assuming 0 <= a <= b."""
|
||||
assert isinstance(a, int)
|
||||
assert isinstance(b, int)
|
||||
assert 0 <= a <= b
|
||||
k = 1
|
||||
while True:
|
||||
if randrange(b, rng) < a and randrange(k, rng) == 0: # if Bern(a/b/k)==1
|
||||
k = k + 1
|
||||
else:
|
||||
return k % 2
|
||||
|
||||
def bern_emab(a, b, rng):
|
||||
"""Sample from Bernoulli(exp(-a/b)), allowing a > b."""
|
||||
while a > b:
|
||||
if bern_em1(rng) == 0:
|
||||
return 0
|
||||
a = a - b
|
||||
return bern_emab1(a, b, rng)
|
||||
|
||||
def geometric(t, rng):
|
||||
"""Sample from geometric(1-exp(-1/t))."""
|
||||
assert isinstance(t, int)
|
||||
assert t > 0
|
||||
while True:
|
||||
u = randrange(t, rng)
|
||||
if bern_emab1(u, t, rng) == 1:
|
||||
while bern_em1(rng) == 1:
|
||||
u = u + t
|
||||
return u
|
||||
|
||||
def dlap(t, rng):
|
||||
"""Sample from discrete Laplace with scale t.
|
||||
|
||||
Pr[x] = exp(-|x|/t) * (exp(1/t)-1)/(exp(1/t)+1). Supported on integers.
|
||||
|
||||
Args:
|
||||
t: The scale.
|
||||
rng: The random number generator.
|
||||
|
||||
Returns:
|
||||
A discrete Laplace sample.
|
||||
"""
|
||||
assert isinstance(t, int)
|
||||
assert t > 0
|
||||
while True:
|
||||
u = geometric(t, rng)
|
||||
b = randrange(2, rng)
|
||||
if b == 1:
|
||||
return u
|
||||
elif u > 0:
|
||||
return -u
|
||||
|
||||
def floorsqrt(x):
|
||||
"""Compute floor(sqrt(x)) exactly."""
|
||||
assert x >= 0
|
||||
a = 0 # maintain a^2<=x.
|
||||
b = 1 # maintain b^2>x.
|
||||
while b * b <= x:
|
||||
b = 2 * b
|
||||
# Do binary search.
|
||||
while a + 1 < b:
|
||||
c = (a + b) // 2
|
||||
if c * c <= x:
|
||||
a = c
|
||||
else:
|
||||
b = c
|
||||
return a
|
||||
|
||||
def dgauss(ss, num, rng):
|
||||
"""Sample from discrete Gaussian.
|
||||
|
||||
Args:
|
||||
ss: Variance proxy, squared scale, sigma^2.
|
||||
num: The number of samples to generate.
|
||||
rng: The random number generator.
|
||||
|
||||
Returns:
|
||||
A list of discrete Gaussian samples.
|
||||
"""
|
||||
ss = fractions.Fraction(ss) # cast to rational for exact arithmetic
|
||||
assert ss > 0
|
||||
t = floorsqrt(ss) + 1
|
||||
results = []
|
||||
trials = 0
|
||||
while len(results) < num:
|
||||
trials = trials + 1
|
||||
y = dlap(t, rng)
|
||||
z = (abs(y) - ss / t)**2 / (2 * ss)
|
||||
if bern_emab(z.numerator, z.denominator, rng) == 1:
|
||||
results.append(y)
|
||||
return results, t, trials
|
||||
|
||||
rng = random.Random(seed)
|
||||
return np.array(dgauss(scale * scale, num_samples, rng)[0])
|
||||
|
||||
|
||||
def dgauss_normalizing_constant(sigma_sq):
|
||||
"""Compute the normalizing constant of the discrete Gaussian.
|
||||
|
||||
Source: https://arxiv.org/pdf/2004.00010.pdf.
|
||||
|
||||
Args:
|
||||
sigma_sq: Variance proxy, squared scale, sigma^2.
|
||||
|
||||
Returns:
|
||||
The normalizing constant.
|
||||
"""
|
||||
original = None
|
||||
poisson = None
|
||||
if sigma_sq <= 1:
|
||||
original = 0
|
||||
x = 1000
|
||||
while x > 0:
|
||||
original = original + math.exp(-x * x / (2.0 * sigma_sq))
|
||||
x = x - 1
|
||||
original = 2 * original + 1
|
||||
|
||||
if sigma_sq * 100 >= 1:
|
||||
poisson = 0
|
||||
y = 1000
|
||||
while y > 0:
|
||||
poisson = poisson + math.exp(-math.pi * math.pi * sigma_sq * 2 * y * y)
|
||||
y = y - 1
|
||||
poisson = math.sqrt(2 * math.pi * sigma_sq) * (1 + 2 * poisson)
|
||||
|
||||
if poisson is None:
|
||||
return original
|
||||
if original is None:
|
||||
return poisson
|
||||
|
||||
scale = max(1, math.sqrt(2 * math.pi * sigma_sq))
|
||||
precision = 1e-15
|
||||
assert -precision * scale <= original - poisson <= precision * scale
|
||||
return (original + poisson) / 2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implements DPQuery interface for distributed discrete Gaussian mechanism."""
|
||||
|
||||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
|
||||
class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery for discrete distributed Gaussian sum queries.
|
||||
|
||||
For each local record, we check the L2 norm bound and add discrete Gaussian
|
||||
noise. In particular, this DPQuery does not perform L2 norm clipping and the
|
||||
norms of the input records are expected to be bounded.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple('_GlobalState',
|
||||
['l2_norm_bound', 'local_stddev'])
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_SampleParams = collections.namedtuple('_SampleParams',
|
||||
['l2_norm_bound', 'local_stddev'])
|
||||
|
||||
def __init__(self, l2_norm_bound, local_stddev):
|
||||
"""Initializes the DistributedDiscreteGaussianSumQuery.
|
||||
|
||||
Args:
|
||||
l2_norm_bound: The L2 norm bound to verify for each record.
|
||||
local_stddev: The scale/stddev of the local discrete Gaussian noise.
|
||||
"""
|
||||
self._l2_norm_bound = l2_norm_bound
|
||||
self._local_stddev = local_stddev
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
del ledger # Unused.
|
||||
raise NotImplementedError('Ledger has not yet been implemented for'
|
||||
'DistributedDiscreteGaussianSumQuery!')
|
||||
|
||||
def initial_global_state(self):
|
||||
return self._GlobalState(
|
||||
tf.cast(self._l2_norm_bound, tf.float32),
|
||||
tf.cast(self._local_stddev, tf.float32))
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
return self._SampleParams(global_state.l2_norm_bound,
|
||||
global_state.local_stddev)
|
||||
|
||||
def _add_local_noise(self, record, local_stddev, shares=1):
|
||||
"""Add local discrete Gaussian noise to the record.
|
||||
|
||||
Args:
|
||||
record: The record to which we generate and add local noise.
|
||||
local_stddev: The scale/stddev of the local discrete Gaussian noise.
|
||||
shares: Number of shares of local noise to generate. Should be 1 for each
|
||||
record. This can be useful when we want to generate multiple noise
|
||||
shares at once.
|
||||
|
||||
Returns:
|
||||
The record with local noise added.
|
||||
"""
|
||||
# Round up the noise as the TF discrete Gaussian sampler only takes
|
||||
# integer noise stddevs for now.
|
||||
ceil_local_stddev = tf.cast(tf.math.ceil(local_stddev), tf.int32)
|
||||
|
||||
def add_noise(v):
|
||||
# Adds an extra dimension for `shares` number of draws.
|
||||
shape = tf.concat([[shares], tf.shape(v)], axis=0)
|
||||
dgauss_noise = discrete_gaussian_utils.sample_discrete_gaussian(
|
||||
scale=ceil_local_stddev, shape=shape, dtype=v.dtype)
|
||||
# Sum across the number of noise shares and add it.
|
||||
noised_v = v + tf.reduce_sum(dgauss_noise, axis=0)
|
||||
# Ensure shape as TF shape inference may fail due to custom noise sampler.
|
||||
noised_v.set_shape(v.shape.as_list())
|
||||
return noised_v
|
||||
|
||||
return tf.nest.map_structure(add_noise, record)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""Check record norm and add noise to the record."""
|
||||
record_as_list = tf.nest.flatten(record)
|
||||
record_as_float_list = [tf.cast(x, tf.float32) for x in record_as_list]
|
||||
tf.nest.map_structure(lambda x: tf.compat.v1.assert_type(x, tf.int32),
|
||||
record_as_list)
|
||||
dependencies = [
|
||||
tf.compat.v1.assert_less_equal(
|
||||
tf.linalg.global_norm(record_as_float_list),
|
||||
params.l2_norm_bound,
|
||||
message=f'Global L2 norm exceeds {params.l2_norm_bound}.')
|
||||
]
|
||||
with tf.control_dependencies(dependencies):
|
||||
result = tf.cond(
|
||||
tf.equal(params.local_stddev, 0), lambda: record,
|
||||
lambda: self._add_local_noise(record, params.local_stddev))
|
||||
return result
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
# Note that by directly returning the aggregate, this assumes that there
|
||||
# will not be missing local noise shares during execution.
|
||||
return sample_state, global_state
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for DistributedDiscreteGaussianQuery."""
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
|
||||
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||
|
||||
ddg_sum_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery
|
||||
|
||||
|
||||
def silence_tf_error_messages(func):
|
||||
"""Decorator that temporarily changes the TF logging levels."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
cur_verbosity = tf.compat.v1.logging.get_verbosity()
|
||||
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
|
||||
func(*args, **kwargs)
|
||||
tf.compat.v1.logging.set_verbosity(cur_verbosity) # Reset verbosity.
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class DistributedDiscreteGaussianQueryTest(tf.test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def test_sum_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([2, 0], dtype=tf.int32)
|
||||
record2 = tf.constant([-1, 1], dtype=tf.int32)
|
||||
|
||||
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1, 1]
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@parameterized.product(sample_size=[1, 3])
|
||||
def test_sum_multiple_shapes(self, sample_size):
|
||||
with self.cached_session() as sess:
|
||||
t1 = tf.constant([2, 0], dtype=tf.int32)
|
||||
t2 = tf.constant([-1, 1, 3], dtype=tf.int32)
|
||||
t3 = tf.constant([-2], dtype=tf.int32)
|
||||
record = [t1, t2, t3]
|
||||
sample = [record] * sample_size
|
||||
|
||||
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, sample)
|
||||
expected = [sample_size * t1, sample_size * t2, sample_size * t3]
|
||||
result, expected = sess.run([query_result, expected])
|
||||
# Use `assertAllClose` for nested structures equality (with tolerance=0).
|
||||
self.assertAllClose(result, expected, atol=0)
|
||||
|
||||
@parameterized.product(sample_size=[1, 3])
|
||||
def test_sum_nested_record_structure(self, sample_size):
|
||||
with self.cached_session() as sess:
|
||||
t1 = tf.constant([1, 0], dtype=tf.int32)
|
||||
t2 = tf.constant([1, 1, 1], dtype=tf.int32)
|
||||
t3 = tf.constant([1], dtype=tf.int32)
|
||||
t4 = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
|
||||
record = [t1, dict(a=t2, b=[t3, (t4, t1)])]
|
||||
sample = [record] * sample_size
|
||||
|
||||
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, sample)
|
||||
result = sess.run(query_result)
|
||||
|
||||
s = sample_size
|
||||
expected = [t1 * s, dict(a=t2 * s, b=[t3 * s, (t4 * s, t1 * s)])]
|
||||
# Use `assertAllClose` for nested structures equality (with tolerance=0)
|
||||
self.assertAllClose(result, expected, atol=0)
|
||||
|
||||
def test_sum_raise_on_float_inputs(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([2, 0], dtype=tf.float32)
|
||||
record2 = tf.constant([-1, 1], dtype=tf.float32)
|
||||
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
sess.run(query_result)
|
||||
|
||||
@parameterized.product(l2_norm_bound=[0, 3, 10, 14.1])
|
||||
@silence_tf_error_messages
|
||||
def test_sum_raise_on_l2_norm_excess(self, l2_norm_bound):
|
||||
with self.cached_session() as sess:
|
||||
record = tf.constant([10, 10], dtype=tf.int32)
|
||||
query = ddg_sum_query(l2_norm_bound=l2_norm_bound, local_stddev=0.0)
|
||||
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
query_result, _ = test_utils.run_query(query, [record])
|
||||
sess.run(query_result)
|
||||
|
||||
def test_sum_float_norm_not_rounded(self):
|
||||
"""Test that the float L2 norm bound doesn't get rounded/casted to integers."""
|
||||
with self.cached_session() as sess:
|
||||
# A casted/rounded norm bound would be insufficient.
|
||||
l2_norm_bound = 14.2
|
||||
record = tf.constant([10, 10], dtype=tf.int32)
|
||||
query = ddg_sum_query(l2_norm_bound=l2_norm_bound, local_stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record])
|
||||
result = sess.run(query_result)
|
||||
expected = [10, 10]
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@parameterized.named_parameters([('2_local_stddev_1_record', 2, 1),
|
||||
('10_local_stddev_4_records', 10, 4),
|
||||
('1000_local_stddev_1_record', 1000, 1),
|
||||
('1000_local_stddev_25_records', 1000, 25)])
|
||||
def test_sum_local_noise_shares(self, local_stddev, num_records):
|
||||
"""Test the noise level of the sum of discrete Gaussians applied locally.
|
||||
|
||||
The sum of discrete Gaussians is not a discrete Gaussian, but it will be
|
||||
extremely close for sigma >= 2. We will thus compare the aggregated noise
|
||||
to a central discrete Gaussian noise with appropriately scaled stddev with
|
||||
some reasonable tolerance.
|
||||
|
||||
Args:
|
||||
local_stddev: The stddev of the local discrete Gaussian noise.
|
||||
num_records: The number of records to be aggregated.
|
||||
"""
|
||||
# Aggregated local noises.
|
||||
num_trials = 1000
|
||||
record = tf.zeros([num_trials], dtype=tf.int32)
|
||||
sample = [record] * num_records
|
||||
query = ddg_sum_query(l2_norm_bound=10.0, local_stddev=local_stddev)
|
||||
query_result, _ = test_utils.run_query(query, sample)
|
||||
|
||||
# Central discrete Gaussian noise.
|
||||
central_stddev = np.sqrt(num_records) * local_stddev
|
||||
central_noise = discrete_gaussian_utils.sample_discrete_gaussian(
|
||||
scale=tf.cast(tf.round(central_stddev), record.dtype),
|
||||
shape=tf.shape(record),
|
||||
dtype=record.dtype)
|
||||
|
||||
agg_noise, central_noise = self.evaluate([query_result, central_noise])
|
||||
|
||||
mean_stddev = central_stddev * np.sqrt(num_trials) / num_trials
|
||||
atol = 3.5 * mean_stddev
|
||||
|
||||
# Use the atol for mean as a rough default atol for stddev/percentile.
|
||||
self.assertAllClose(np.mean(agg_noise), np.mean(central_noise), atol=atol)
|
||||
self.assertAllClose(np.std(agg_noise), np.std(central_noise), atol=atol)
|
||||
self.assertAllClose(
|
||||
np.percentile(agg_noise, [25, 50, 75]),
|
||||
np.percentile(central_noise, [25, 50, 75]),
|
||||
atol=atol)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue