forked from 626_privacy/tensorflow_privacy
Add quantile_adaptive_clip_sum_query, which dynamically adjusts the clipping norm so that a specified fraction of records per sample is clipped.
PiperOrigin-RevId: 248201320
parent 1d1a6e087a
commit aaf029edad
4 changed files with 594 additions and 0 deletions
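For context, a minimal sketch of how the new query is driven (the pattern mirrors the new test file below; test_utils.run_query is the existing helper used by the new tests, and all parameter values here are illustrative only, not part of this commit's API surface):

# Illustrative usage sketch only; values are not from this commit.
import tensorflow as tf

from privacy.dp_query import quantile_adaptive_clip_sum_query
from privacy.dp_query import test_utils

tf.enable_eager_execution()

query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
    initial_l2_norm_clip=10.0,
    noise_multiplier=1.0,
    target_unclipped_quantile=0.5,  # adapt until ~half of records are clipped
    learning_rate=0.5,
    clipped_count_stddev=0.5,
    expected_num_records=2.0)

records = [tf.constant([1.0, 2.0]), tf.constant([3.0, -1.0])]
global_state = query.initial_global_state()

# One round: clip each record, sum, add noise, and adapt the clipping norm.
noised_sum, global_state = test_utils.run_query(query, records, global_state)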
privacy/__init__.py

@@ -33,6 +33,8 @@ else:
  from privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
  from privacy.dp_query.no_privacy_query import NoPrivacySumQuery
  from privacy.dp_query.normalized_query import NormalizedQuery
  from privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
  from privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery

  from privacy.optimizers.dp_optimizer import DPAdagradGaussianOptimizer
  from privacy.optimizers.dp_optimizer import DPAdagradOptimizer
privacy/dp_query/BUILD

@@ -105,6 +105,29 @@ py_test(
    ],
)

py_library(
    name = "quantile_adaptive_clip_sum_query",
    srcs = ["quantile_adaptive_clip_sum_query.py"],
    deps = [
        ":dp_query",
        ":gaussian_query",
        ":normalized_query",
        "//third_party/py/tensorflow",
    ],
)

py_test(
    name = "quantile_adaptive_clip_sum_query_test",
    srcs = ["quantile_adaptive_clip_sum_query_test.py"],
    deps = [
        ":quantile_adaptive_clip_sum_query",
        ":test_utils",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
        "//third_party/py/tensorflow_privacy/privacy/analysis:privacy_ledger",
    ],
)

py_library(
    name = "test_utils",
    srcs = ["test_utils.py"],
privacy/dp_query/quantile_adaptive_clip_sum_query.py (new file, 271 lines)

@@ -0,0 +1,271 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements DPQuery interface for adaptive clip queries.

Instead of a fixed clipping norm specified in advance, the clipping norm is
dynamically adjusted to match a target fraction of clipped updates per sample,
where the actual fraction of clipped updates is itself estimated in a
differentially private manner. For details see Thakkar et al., "Differentially
Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871].
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import tensorflow as tf

from privacy.dp_query import dp_query
from privacy.dp_query import gaussian_query
from privacy.dp_query import normalized_query

nest = tf.contrib.framework.nest


class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
  """DPQuery for sum queries with adaptive clipping.

  Clipping norm is tuned adaptively to converge to a value such that a
  specified quantile of updates are clipped.
  """

  # pylint: disable=invalid-name
  _GlobalState = collections.namedtuple(
      '_GlobalState', ['l2_norm_clip', 'sum_state', 'clipped_fraction_state'])

  # pylint: disable=invalid-name
  _SampleState = collections.namedtuple(
      '_SampleState', ['sum_state', 'clipped_fraction_state'])

  # pylint: disable=invalid-name
  _SampleParams = collections.namedtuple(
      '_SampleParams', ['sum_params', 'clipped_fraction_params'])

  def __init__(
      self,
      initial_l2_norm_clip,
      noise_multiplier,
      target_unclipped_quantile,
      learning_rate,
      clipped_count_stddev,
      expected_num_records,
      ledger=None):
    """Initializes the QuantileAdaptiveClipSumQuery.

    Args:
      initial_l2_norm_clip: The initial value of the clipping norm.
      noise_multiplier: The multiplier of the l2_norm_clip used to compute the
        stddev of the noise added to the output of the sum query.
      target_unclipped_quantile: The desired quantile of updates that should
        be unclipped. I.e., a value of 0.8 means a value of l2_norm_clip
        should be found for which approximately 20% of updates are clipped
        each round.
      learning_rate: The learning rate for the clipping norm adaptation. A
        rate of r means that the clipping norm will change by a maximum of r
        at each step. This maximum is attained when the estimated unclipped
        quantile differs from the target by 1.0. Can be a tf.Variable, for
        example to implement a learning rate schedule.
      clipped_count_stddev: The stddev of the noise added to the
        clipped_count. Since the sensitivity of the clipped count is 0.5, as
        a rule of thumb it should be about 0.5 for reasonable privacy.
      expected_num_records: The expected number of records per round, used to
        estimate the clipped count quantile.
      ledger: The privacy ledger to which queries should be recorded.
    """
    self._initial_l2_norm_clip = tf.cast(initial_l2_norm_clip, tf.float32)
    self._noise_multiplier = tf.cast(noise_multiplier, tf.float32)
    self._target_unclipped_quantile = tf.cast(
        target_unclipped_quantile, tf.float32)
    self._learning_rate = tf.cast(learning_rate, tf.float32)

    self._l2_norm_clip = tf.Variable(self._initial_l2_norm_clip)
    self._sum_stddev = tf.Variable(
        self._initial_l2_norm_clip * self._noise_multiplier)
    self._sum_query = gaussian_query.GaussianSumQuery(
        self._l2_norm_clip,
        self._sum_stddev,
        ledger)

    # self._clipped_fraction_query is a DPQuery used to estimate the fraction
    # of records that are clipped. It accumulates an indicator 0/1 of whether
    # each record is clipped, and normalizes by the expected number of
    # records. In practice, we accumulate clipped counts shifted by -0.5 so
    # they are centered at zero. This makes the sensitivity of the clipped
    # count query 0.5 instead of 1.0, since the maximum that a single record
    # can affect the count is 0.5. Note that although the l2_norm_clip of the
    # clipped fraction query is 0.5, no clipping will ever actually occur
    # because the value of each record is always +/-0.5.
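    # For example (illustrative numbers, not from this commit): with 100
    # records of which 30 are clipped, the accumulated shifted sum is
    # 30 * 0.5 + 70 * (-0.5) = -20. Dividing by expected_num_records=100
    # gives -0.2, and unshifting by +0.5 in get_noised_result recovers the
    # clipped fraction 0.3.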
    self._clipped_fraction_query = gaussian_query.GaussianAverageQuery(
        l2_norm_clip=0.5,
        sum_stddev=clipped_count_stddev,
        denominator=expected_num_records,
        ledger=ledger)

  def initial_global_state(self):
    """See base class."""
    return self._GlobalState(
        self._initial_l2_norm_clip,
        self._sum_query.initial_global_state(),
        self._clipped_fraction_query.initial_global_state())

  @tf.function
  def derive_sample_params(self, global_state):
    """See base class."""
    gs = global_state

    # Assign values to variables that the inner sum query uses.
    tf.assign(self._l2_norm_clip, gs.l2_norm_clip)
    tf.assign(self._sum_stddev, gs.l2_norm_clip * self._noise_multiplier)
    sum_params = self._sum_query.derive_sample_params(gs.sum_state)
    clipped_fraction_params = self._clipped_fraction_query.derive_sample_params(
        gs.clipped_fraction_state)
    return self._SampleParams(sum_params, clipped_fraction_params)

  def initial_sample_state(self, global_state, template):
    """See base class."""
    clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
        global_state.clipped_fraction_state, tf.constant(0.0))
    sum_state = self._sum_query.initial_sample_state(
        global_state.sum_state, template)
    return self._SampleState(sum_state, clipped_fraction_state)

  def preprocess_record(self, params, record):
    preprocessed_sum_record, global_norm = (
        self._sum_query.preprocess_record_impl(params.sum_params, record))

    # Note we are relying on the internals of GaussianSumQuery here. If we
    # want to open this up to other kinds of inner queries we'd have to do
    # this in a more general way.
    l2_norm_clip = params.sum_params

    # We accumulate clipped counts shifted by -0.5 so they are centered at
    # zero. This makes the sensitivity of the clipped count query 0.5 instead
    # of 1.0.
    was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5

    preprocessed_clipped_fraction_record = (
        self._clipped_fraction_query.preprocess_record(
            params.clipped_fraction_params, was_clipped))

    return preprocessed_sum_record, preprocessed_clipped_fraction_record

  def accumulate_preprocessed_record(
      self, sample_state, preprocessed_record, weight=1):
    """See base class."""
    preprocessed_sum_record, preprocessed_clipped_fraction_record = (
        preprocessed_record)
    sum_state = self._sum_query.accumulate_preprocessed_record(
        sample_state.sum_state, preprocessed_sum_record)

    clipped_fraction_state = (
        self._clipped_fraction_query.accumulate_preprocessed_record(
            sample_state.clipped_fraction_state,
            preprocessed_clipped_fraction_record))
    return self._SampleState(sum_state, clipped_fraction_state)

  def merge_sample_states(self, sample_state_1, sample_state_2):
    """See base class."""
    return self._SampleState(
        self._sum_query.merge_sample_states(
            sample_state_1.sum_state,
            sample_state_2.sum_state),
        self._clipped_fraction_query.merge_sample_states(
            sample_state_1.clipped_fraction_state,
            sample_state_2.clipped_fraction_state))

  def get_noised_result(self, sample_state, global_state):
    """See base class."""
    gs = global_state

    noised_vectors, sum_state = self._sum_query.get_noised_result(
        sample_state.sum_state, gs.sum_state)

    clipped_fraction_result, new_clipped_fraction_state = (
        self._clipped_fraction_query.get_noised_result(
            sample_state.clipped_fraction_state,
            gs.clipped_fraction_state))

    # Unshift the clipped fraction by 0.5. (See comment in preprocess_record.)
    clipped_quantile = clipped_fraction_result + 0.5
    unclipped_quantile = 1.0 - clipped_quantile

    # Protect against out-of-range estimates.
    unclipped_quantile = tf.minimum(1.0, tf.maximum(0.0, unclipped_quantile))

    # The loss function is convex, with derivative in [-1, 1], and minimized
    # when the true quantile matches the target.
    loss_grad = unclipped_quantile - self._target_unclipped_quantile

    new_l2_norm_clip = gs.l2_norm_clip - self._learning_rate * loss_grad
    new_l2_norm_clip = tf.maximum(0.0, new_l2_norm_clip)

    new_global_state = self._GlobalState(
        new_l2_norm_clip,
        sum_state,
        new_clipped_fraction_state)

    return noised_vectors, new_global_state


class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
  """DPQuery for average queries with adaptive clipping.

  Clipping norm is tuned adaptively to converge to a value such that a
  specified quantile of updates are clipped.

  Note that we use "fixed-denominator" estimation: the denominator should be
  specified as the expected number of records per sample. Accumulating the
  denominator separately would also be possible but would produce a higher
  variance estimator.
  """

  def __init__(
      self,
      initial_l2_norm_clip,
      noise_multiplier,
      denominator,
      target_unclipped_quantile,
      learning_rate,
      clipped_count_stddev,
      expected_num_records,
      ledger=None):
    """Initializes the QuantileAdaptiveClipAverageQuery.

    Args:
      initial_l2_norm_clip: The initial value of the clipping norm.
      noise_multiplier: The multiplier of the l2_norm_clip used to compute the
        stddev of the noise.
      denominator: The normalization constant (applied after noise is added to
        the sum).
      target_unclipped_quantile: The desired quantile of updates that should
        be unclipped.
      learning_rate: The learning rate for the clipping norm adaptation. A
        rate of r means that the clipping norm will change by a maximum of r
        at each step. The maximum is attained when the estimated unclipped
        quantile differs from the target by 1.0.
      clipped_count_stddev: The stddev of the noise added to the
        clipped_count. Since the sensitivity of the clipped count is 0.5, as
        a rule of thumb it should be about 0.5 for reasonable privacy.
      expected_num_records: The expected number of records, used to estimate
        the clipped count quantile.
      ledger: The privacy ledger to which queries should be recorded.
    """
    numerator_query = QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip,
        noise_multiplier,
        target_unclipped_quantile,
        learning_rate,
        clipped_count_stddev,
        expected_num_records,
        ledger)
    super(QuantileAdaptiveClipAverageQuery, self).__init__(
        numerator_query=numerator_query,
        denominator=denominator)
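The adaptation step in get_noised_result reduces to new_clip = max(0, clip - learning_rate * (unclipped_quantile - target)). As a sanity check, here is a minimal plain-Python sketch of that arithmetic (no noise, no TensorFlow) that reproduces the clip sequence asserted in test_adaptation_target_zero below:

# Standalone sketch of the clip adaptation rule; values mirror
# test_adaptation_target_zero (records with norms 8.5 and 7.25, initial
# clip 10.0, learning rate 1.0, target unclipped quantile 0.0).
norms = [8.5, 7.25]
clip, lr, target = 10.0, 1.0, 0.0

for _ in range(5):
  # A record is unclipped when its norm is strictly below the clip
  # (the query counts norm >= clip as clipped).
  unclipped = sum(n < clip for n in norms) / len(norms)
  clip = max(0.0, clip - lr * (unclipped - target))
  print(clip)  # prints 9.0, 8.0, 7.5, 7.0, 7.0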
privacy/dp_query/quantile_adaptive_clip_sum_query_test.py (new file, 298 lines)

@@ -0,0 +1,298 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for QuantileAdaptiveClipSumQuery."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from privacy.analysis import privacy_ledger
from privacy.dp_query import quantile_adaptive_clip_sum_query
from privacy.dp_query import test_utils

tf.enable_eager_execution()


class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):

  def test_sum_no_clip_no_noise(self):
    record1 = tf.constant([2.0, 0.0])
    record2 = tf.constant([-1.0, 1.0])

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=10.0,
        noise_multiplier=0.0,
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)
    query_result, _ = test_utils.run_query(query, [record1, record2])
    result = query_result.numpy()
    expected = [1.0, 1.0]
    self.assertAllClose(result, expected)

  def test_sum_with_clip_no_noise(self):
    record1 = tf.constant([-6.0, 8.0])  # Clipped to [-3.0, 4.0].
    record2 = tf.constant([4.0, -3.0])  # Not clipped.

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=5.0,
        noise_multiplier=0.0,
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)

    query_result, _ = test_utils.run_query(query, [record1, record2])
    result = query_result.numpy()
    expected = [1.0, 1.0]
    self.assertAllClose(result, expected)

  def test_sum_with_noise(self):
    record1, record2 = 2.71828, 3.14159
    stddev = 1.0
    clip = 5.0

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=clip,
        noise_multiplier=stddev / clip,
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)

    noised_sums = []
    for _ in range(1000):
      query_result, _ = test_utils.run_query(query, [record1, record2])
      noised_sums.append(query_result.numpy())

    result_stddev = np.std(noised_sums)
    self.assertNear(result_stddev, stddev, 0.1)

  def test_average_no_noise(self):
    record1 = tf.constant([5.0, 0.0])  # Clipped to [3.0, 0.0].
    record2 = tf.constant([-1.0, 2.0])  # Not clipped.

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
        initial_l2_norm_clip=3.0,
        noise_multiplier=0.0,
        denominator=2.0,
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)
    query_result, _ = test_utils.run_query(query, [record1, record2])
    result = query_result.numpy()
    expected_average = [1.0, 1.0]
    self.assertAllClose(result, expected_average)

  def test_average_with_noise(self):
    record1, record2 = 2.71828, 3.14159
    sum_stddev = 1.0
    denominator = 2.0
    clip = 3.0

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
        initial_l2_norm_clip=clip,
        noise_multiplier=sum_stddev / clip,
        denominator=denominator,
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)

    noised_averages = []
    for _ in range(1000):
      query_result, _ = test_utils.run_query(query, [record1, record2])
      noised_averages.append(query_result.numpy())

    result_stddev = np.std(noised_averages)
    avg_stddev = sum_stddev / denominator
    self.assertNear(result_stddev, avg_stddev, 0.1)

  def test_adaptation_target_zero(self):
    record1 = tf.constant([8.5])
    record2 = tf.constant([-7.25])

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=10.0,
        noise_multiplier=0.0,
        target_unclipped_quantile=0.0,
        learning_rate=1.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)

    global_state = query.initial_global_state()

    initial_clip = global_state.l2_norm_clip
    self.assertAllClose(initial_clip, 10.0)

    # On the first two iterations, nothing is clipped, so the clip goes down
    # by 1.0 (the learning rate). When the clip reaches 8.0, one record is
    # clipped, so the clip goes down by only 0.5. After two more iterations,
    # both records are clipped, and the clip norm stays there (at 7.0).

    expected_sums = [1.25, 1.25, 0.75, 0.25, 0.0]
    expected_clips = [9.0, 8.0, 7.5, 7.0, 7.0]
    for expected_sum, expected_clip in zip(expected_sums, expected_clips):
      actual_sum, global_state = test_utils.run_query(
          query, [record1, record2], global_state)

      actual_clip = global_state.l2_norm_clip

      self.assertAllClose(actual_clip.numpy(), expected_clip)
      self.assertAllClose(actual_sum.numpy(), (expected_sum,))

  def test_adaptation_target_one(self):
    record1 = tf.constant([-1.5])
    record2 = tf.constant([2.75])

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=0.0,
        noise_multiplier=0.0,
        target_unclipped_quantile=1.0,
        learning_rate=1.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)

    global_state = query.initial_global_state()

    initial_clip = global_state.l2_norm_clip
    self.assertAllClose(initial_clip, 0.0)

    # On the first two iterations, both records are clipped, so the clip goes
    # up by 1.0 (the learning rate). When the clip reaches 2.0, only one
    # record is clipped, so the clip goes up by only 0.5. After two more
    # iterations, neither record is clipped, and the clip norm stays there
    # (at 3.0).

    expected_sums = [0.0, 0.0, 0.5, 1.0, 1.25]
    expected_clips = [1.0, 2.0, 2.5, 3.0, 3.0]
    for expected_sum, expected_clip in zip(expected_sums, expected_clips):
      actual_sum, global_state = test_utils.run_query(
          query, [record1, record2], global_state)

      actual_clip = global_state.l2_norm_clip

      self.assertAllClose(actual_clip.numpy(), expected_clip)
      self.assertAllClose(actual_sum.numpy(), (expected_sum,))

  def test_adaptation_linspace(self):
    # 21 records equally spaced from 0 to 10 in 0.5 increments.
    # Test that with a decaying learning rate we converge to the correct
    # median with error at most 0.25.
    records = [tf.constant(x) for x in np.linspace(
        0.0, 10.0, num=21, dtype=np.float32)]

    learning_rate = tf.Variable(1.0)

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=0.0,
        noise_multiplier=0.0,
        target_unclipped_quantile=0.5,
        learning_rate=learning_rate,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)

    global_state = query.initial_global_state()

    for t in range(50):
      tf.assign(learning_rate, 1.0 / np.sqrt(t + 1))
      _, global_state = test_utils.run_query(query, records, global_state)

      actual_clip = global_state.l2_norm_clip

      if t > 40:
        self.assertNear(actual_clip, 5.0, 0.25)

  def test_adaptation_all_equal(self):
    # 20 equal records. Test that with a decaying learning rate we converge
    # to that record's norm and bounce around it.
    records = [tf.constant(5.0)] * 20

    learning_rate = tf.Variable(1.0)

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=0.0,
        noise_multiplier=0.0,
        target_unclipped_quantile=0.5,
        learning_rate=learning_rate,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)

    global_state = query.initial_global_state()

    for t in range(50):
      tf.assign(learning_rate, 1.0 / np.sqrt(t + 1))
      _, global_state = test_utils.run_query(query, records, global_state)

      actual_clip = global_state.l2_norm_clip

      if t > 40:
        self.assertNear(actual_clip, 5.0, 0.25)

  def test_ledger(self):
    record1 = tf.constant([8.5])
    record2 = tf.constant([-7.25])

    population_size = tf.Variable(0)
    selection_probability = tf.Variable(0.0)
    ledger = privacy_ledger.PrivacyLedger(
        population_size, selection_probability, 50, 50)

    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
        initial_l2_norm_clip=10.0,
        noise_multiplier=1.0,
        target_unclipped_quantile=0.0,
        learning_rate=1.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0,
        ledger=ledger)

    query = privacy_ledger.QueryWithLedger(query, ledger)

    # First sample.
    tf.assign(population_size, 10)
    tf.assign(selection_probability, 0.1)
    _, global_state = test_utils.run_query(query, [record1, record2])

    expected_queries = [[0.5, 0.0], [10.0, 10.0]]
    formatted = ledger.get_formatted_ledger_eager()
    sample_1 = formatted[0]
    self.assertAllClose(sample_1.population_size, 10.0)
    self.assertAllClose(sample_1.selection_probability, 0.1)
    self.assertAllClose(sample_1.queries, expected_queries)

    # Second sample.
    tf.assign(population_size, 20)
    tf.assign(selection_probability, 0.2)
    test_utils.run_query(query, [record1, record2], global_state)

    formatted = ledger.get_formatted_ledger_eager()
    sample_1, sample_2 = formatted
    self.assertAllClose(sample_1.population_size, 10.0)
    self.assertAllClose(sample_1.selection_probability, 0.1)
    self.assertAllClose(sample_1.queries, expected_queries)

    expected_queries_2 = [[0.5, 0.0], [9.0, 9.0]]
    self.assertAllClose(sample_2.population_size, 20.0)
    self.assertAllClose(sample_2.selection_probability, 0.2)
    self.assertAllClose(sample_2.queries, expected_queries_2)


if __name__ == '__main__':
  tf.test.main()
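For reference, the ledger entries asserted in test_ledger appear to decode as one [l2_norm_clip, noise_stddev] pair per sub-query, with the clipped-count query listed first and the sum query second. A small sketch of the arithmetic (values taken from the test above):

# Decoding the asserted ledger entries; the pairing is [clip, noise stddev].
clip, noise_multiplier, clipped_count_stddev = 10.0, 1.0, 0.0
sample_1_queries = [
    [0.5, clipped_count_stddev],      # clipped-count sub-query (clip fixed at 0.5)
    [clip, clip * noise_multiplier],  # sum sub-query
]
assert sample_1_queries == [[0.5, 0.0], [10.0, 10.0]]
# After one round with target_unclipped_quantile=0.0 and learning_rate=1.0,
# neither record is clipped at clip 10.0, so the clip adapts to 9.0 and the
# second sample logs [[0.5, 0.0], [9.0, 9.0]].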