Refactor quantile estimation logic from QuantileAdaptiveClipSumQuery so it can be used for other purposes.
PiperOrigin-RevId: 315297665
This commit is contained in:
parent
261ab4f28e
commit
cec011e2a7
7 changed files with 458 additions and 108 deletions
|
@ -36,6 +36,7 @@ else:
|
||||||
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
|
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacySumQuery
|
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacySumQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.normalized_query import NormalizedQuery
|
from tensorflow_privacy.privacy.dp_query.normalized_query import NormalizedQuery
|
||||||
|
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery
|
||||||
|
|
||||||
|
|
|
@ -120,6 +120,7 @@ py_library(
|
||||||
":dp_query",
|
":dp_query",
|
||||||
":gaussian_query",
|
":gaussian_query",
|
||||||
":normalized_query",
|
":normalized_query",
|
||||||
|
":quantile_estimator_query",
|
||||||
"//third_party/py/tensorflow",
|
"//third_party/py/tensorflow",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -139,6 +140,30 @@ py_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "quantile_estimator_query",
|
||||||
|
srcs = ["quantile_estimator_query.py"],
|
||||||
|
deps = [
|
||||||
|
":dp_query",
|
||||||
|
":gaussian_query",
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "quantile_estimator_query_test",
|
||||||
|
srcs = ["quantile_estimator_query_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":quantile_estimator_query",
|
||||||
|
":test_utils",
|
||||||
|
"//learning/brain/public:disable_tf2", # build_cleaner: keep; go/disable_tf2
|
||||||
|
"//third_party/py/absl/testing:parameterized",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "test_utils",
|
name = "test_utils",
|
||||||
srcs = ["test_utils.py"],
|
srcs = ["test_utils.py"],
|
||||||
|
|
|
@ -26,7 +26,7 @@ import tensorflow.compat.v1 as tf
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
|
||||||
|
|
||||||
class NormalizedQuery(dp_query.DPQuery):
|
class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""DPQuery for queries with a DPQuery numerator and fixed denominator."""
|
"""DPQuery for queries with a DPQuery numerator and fixed denominator."""
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
@ -37,7 +37,7 @@ class NormalizedQuery(dp_query.DPQuery):
|
||||||
"""Initializer for NormalizedQuery.
|
"""Initializer for NormalizedQuery.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
numerator_query: A DPQuery for the numerator.
|
numerator_query: A SumAggregationDPQuery for the numerator.
|
||||||
denominator: A value for the denominator. May be None if it will be
|
denominator: A value for the denominator. May be None if it will be
|
||||||
supplied via the set_denominator function before get_noised_result is
|
supplied via the set_denominator function before get_noised_result is
|
||||||
called.
|
called.
|
||||||
|
@ -45,6 +45,8 @@ class NormalizedQuery(dp_query.DPQuery):
|
||||||
self._numerator = numerator_query
|
self._numerator = numerator_query
|
||||||
self._denominator = denominator
|
self._denominator = denominator
|
||||||
|
|
||||||
|
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
self._numerator.set_ledger(ledger)
|
self._numerator.set_ledger(ledger)
|
||||||
|
@ -70,12 +72,6 @@ class NormalizedQuery(dp_query.DPQuery):
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
return self._numerator.preprocess_record(params, record)
|
return self._numerator.preprocess_record(params, record)
|
||||||
|
|
||||||
def accumulate_preprocessed_record(
|
|
||||||
self, sample_state, preprocessed_record):
|
|
||||||
"""See base class."""
|
|
||||||
return self._numerator.accumulate_preprocessed_record(
|
|
||||||
sample_state, preprocessed_record)
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||||
|
@ -85,7 +81,3 @@ class NormalizedQuery(dp_query.DPQuery):
|
||||||
|
|
||||||
return (tf.nest.map_structure(normalize, noised_sum),
|
return (tf.nest.map_structure(normalize, noised_sum),
|
||||||
self._GlobalState(new_sum_global_state, global_state.denominator))
|
self._GlobalState(new_sum_global_state, global_state.denominator))
|
||||||
|
|
||||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
|
||||||
"""See base class."""
|
|
||||||
return self._numerator.merge_sample_states(sample_state_1, sample_state_2)
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ import tensorflow.compat.v1 as tf
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
|
||||||
|
|
||||||
|
|
||||||
class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
@ -45,18 +46,16 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
_GlobalState = collections.namedtuple(
|
_GlobalState = collections.namedtuple(
|
||||||
'_GlobalState', [
|
'_GlobalState', [
|
||||||
'noise_multiplier',
|
'noise_multiplier',
|
||||||
'target_unclipped_quantile',
|
|
||||||
'learning_rate',
|
|
||||||
'sum_state',
|
'sum_state',
|
||||||
'clipped_fraction_state'])
|
'quantile_estimator_state'])
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
_SampleState = collections.namedtuple(
|
_SampleState = collections.namedtuple(
|
||||||
'_SampleState', ['sum_state', 'clipped_fraction_state'])
|
'_SampleState', ['sum_state', 'quantile_estimator_state'])
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
_SampleParams = collections.namedtuple(
|
_SampleParams = collections.namedtuple(
|
||||||
'_SampleParams', ['sum_params', 'clipped_fraction_params'])
|
'_SampleParams', ['sum_params', 'quantile_estimator_params'])
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -66,7 +65,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
learning_rate,
|
learning_rate,
|
||||||
clipped_count_stddev,
|
clipped_count_stddev,
|
||||||
expected_num_records,
|
expected_num_records,
|
||||||
geometric_update=False):
|
geometric_update=True):
|
||||||
"""Initializes the QuantileAdaptiveClipSumQuery.
|
"""Initializes the QuantileAdaptiveClipSumQuery.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -86,126 +85,79 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
estimate the clipped count quantile.
|
estimate the clipped count quantile.
|
||||||
geometric_update: If True, use geometric updating of clip.
|
geometric_update: If True, use geometric updating of clip.
|
||||||
"""
|
"""
|
||||||
self._initial_l2_norm_clip = initial_l2_norm_clip
|
|
||||||
self._noise_multiplier = noise_multiplier
|
self._noise_multiplier = noise_multiplier
|
||||||
self._target_unclipped_quantile = target_unclipped_quantile
|
|
||||||
self._learning_rate = learning_rate
|
|
||||||
|
|
||||||
# Initialize sum query's global state with None, to be set later.
|
self._quantile_estimator_query = quantile_estimator_query.QuantileEstimatorQuery(
|
||||||
self._sum_query = gaussian_query.GaussianSumQuery(None, None)
|
initial_l2_norm_clip,
|
||||||
|
target_unclipped_quantile,
|
||||||
|
learning_rate,
|
||||||
|
clipped_count_stddev,
|
||||||
|
expected_num_records,
|
||||||
|
geometric_update)
|
||||||
|
|
||||||
# self._clipped_fraction_query is a DPQuery used to estimate the fraction of
|
self._sum_query = gaussian_query.GaussianSumQuery(
|
||||||
# records that are clipped. It accumulates an indicator 0/1 of whether each
|
initial_l2_norm_clip,
|
||||||
# record is clipped, and normalizes by the expected number of records. In
|
noise_multiplier * initial_l2_norm_clip)
|
||||||
# practice, we accumulate clipped counts shifted by -0.5 so they are
|
|
||||||
# centered at zero. This makes the sensitivity of the clipped count query
|
|
||||||
# 0.5 instead of 1.0, since the maximum that a single record could affect
|
|
||||||
# the count is 0.5. Note that although the l2_norm_clip of the clipped
|
|
||||||
# fraction query is 0.5, no clipping will ever actually occur because the
|
|
||||||
# value of each record is always +/-0.5.
|
|
||||||
self._clipped_fraction_query = gaussian_query.GaussianAverageQuery(
|
|
||||||
l2_norm_clip=0.5,
|
|
||||||
sum_stddev=clipped_count_stddev,
|
|
||||||
denominator=expected_num_records)
|
|
||||||
|
|
||||||
self._geometric_update = geometric_update
|
assert isinstance(self._sum_query, dp_query.SumAggregationDPQuery)
|
||||||
|
assert isinstance(self._quantile_estimator_query,
|
||||||
|
dp_query.SumAggregationDPQuery)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
self._sum_query.set_ledger(ledger)
|
self._sum_query.set_ledger(ledger)
|
||||||
self._clipped_fraction_query.set_ledger(ledger)
|
self._quantile_estimator_query.set_ledger(ledger)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
initial_l2_norm_clip = tf.cast(self._initial_l2_norm_clip, tf.float32)
|
|
||||||
noise_multiplier = tf.cast(self._noise_multiplier, tf.float32)
|
|
||||||
target_unclipped_quantile = tf.cast(self._target_unclipped_quantile,
|
|
||||||
tf.float32)
|
|
||||||
learning_rate = tf.cast(self._learning_rate, tf.float32)
|
|
||||||
sum_stddev = initial_l2_norm_clip * noise_multiplier
|
|
||||||
|
|
||||||
sum_query_global_state = self._sum_query.make_global_state(
|
|
||||||
l2_norm_clip=initial_l2_norm_clip,
|
|
||||||
stddev=sum_stddev)
|
|
||||||
|
|
||||||
return self._GlobalState(
|
return self._GlobalState(
|
||||||
noise_multiplier,
|
tf.cast(self._noise_multiplier, tf.float32),
|
||||||
target_unclipped_quantile,
|
self._sum_query.initial_global_state(),
|
||||||
learning_rate,
|
self._quantile_estimator_query.initial_global_state())
|
||||||
sum_query_global_state,
|
|
||||||
self._clipped_fraction_query.initial_global_state())
|
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
return self._SampleParams(
|
||||||
# Assign values to variables that inner sum query uses.
|
self._sum_query.derive_sample_params(global_state.sum_state),
|
||||||
sum_params = self._sum_query.derive_sample_params(global_state.sum_state)
|
self._quantile_estimator_query.derive_sample_params(
|
||||||
clipped_fraction_params = self._clipped_fraction_query.derive_sample_params(
|
global_state.quantile_estimator_state))
|
||||||
global_state.clipped_fraction_state)
|
|
||||||
return self._SampleParams(sum_params, clipped_fraction_params)
|
|
||||||
|
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
sum_state = self._sum_query.initial_sample_state(template)
|
return self._SampleState(
|
||||||
clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
|
self._sum_query.initial_sample_state(template),
|
||||||
tf.constant(0.0))
|
self._quantile_estimator_query.initial_sample_state(tf.constant(0.0)))
|
||||||
return self._SampleState(sum_state, clipped_fraction_state)
|
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
preprocessed_sum_record, global_norm = (
|
clipped_record, global_norm = (
|
||||||
self._sum_query.preprocess_record_impl(params.sum_params, record))
|
self._sum_query.preprocess_record_impl(params.sum_params, record))
|
||||||
|
|
||||||
# Note we are relying on the internals of GaussianSumQuery here. If we want
|
was_unclipped = self._quantile_estimator_query.preprocess_record(
|
||||||
# to open this up to other kinds of inner queries we'd have to do this in a
|
params.quantile_estimator_params, global_norm)
|
||||||
# more general way.
|
|
||||||
l2_norm_clip = params.sum_params
|
|
||||||
|
|
||||||
# We accumulate clipped counts shifted by 0.5 so they are centered at zero.
|
return self._SampleState(clipped_record, was_unclipped)
|
||||||
# This makes the sensitivity of the clipped count query 0.5 instead of 1.0.
|
|
||||||
was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5
|
|
||||||
|
|
||||||
return self._SampleState(preprocessed_sum_record, was_clipped)
|
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
gs = global_state
|
|
||||||
|
|
||||||
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
||||||
sample_state.sum_state, gs.sum_state)
|
sample_state.sum_state, global_state.sum_state)
|
||||||
del sum_state # Unused. To be set explicitly later.
|
del sum_state # To be set explicitly later when we know the new clip.
|
||||||
|
|
||||||
clipped_fraction_result, new_clipped_fraction_state = (
|
new_l2_norm_clip, new_quantile_estimator_state = (
|
||||||
self._clipped_fraction_query.get_noised_result(
|
self._quantile_estimator_query.get_noised_result(
|
||||||
sample_state.clipped_fraction_state,
|
sample_state.quantile_estimator_state,
|
||||||
gs.clipped_fraction_state))
|
global_state.quantile_estimator_state))
|
||||||
|
|
||||||
# Unshift clipped percentile by 0.5. (See comment in initializer.)
|
|
||||||
clipped_quantile = clipped_fraction_result + 0.5
|
|
||||||
unclipped_quantile = 1.0 - clipped_quantile
|
|
||||||
|
|
||||||
# Protect against out-of-range estimates.
|
|
||||||
unclipped_quantile = tf.minimum(1.0, tf.maximum(0.0, unclipped_quantile))
|
|
||||||
|
|
||||||
# Loss function is convex, with derivative in [-1, 1], and minimized when
|
|
||||||
# the true quantile matches the target.
|
|
||||||
loss_grad = unclipped_quantile - global_state.target_unclipped_quantile
|
|
||||||
|
|
||||||
update = global_state.learning_rate * loss_grad
|
|
||||||
|
|
||||||
if self._geometric_update:
|
|
||||||
new_l2_norm_clip = gs.sum_state.l2_norm_clip * tf.math.exp(-update)
|
|
||||||
else:
|
|
||||||
new_l2_norm_clip = tf.math.maximum(0.0,
|
|
||||||
gs.sum_state.l2_norm_clip - update)
|
|
||||||
|
|
||||||
|
new_l2_norm_clip = tf.maximum(new_l2_norm_clip, 0.0)
|
||||||
new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
|
new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
|
||||||
new_sum_query_global_state = self._sum_query.make_global_state(
|
new_sum_query_state = self._sum_query.make_global_state(
|
||||||
l2_norm_clip=new_l2_norm_clip,
|
l2_norm_clip=new_l2_norm_clip,
|
||||||
stddev=new_sum_stddev)
|
stddev=new_sum_stddev)
|
||||||
|
|
||||||
new_global_state = global_state._replace(
|
new_global_state = self._GlobalState(
|
||||||
sum_state=new_sum_query_global_state,
|
global_state.noise_multiplier,
|
||||||
clipped_fraction_state=new_clipped_fraction_state)
|
new_sum_query_state,
|
||||||
|
new_quantile_estimator_state)
|
||||||
|
|
||||||
return noised_vectors, new_global_state
|
return noised_vectors, new_global_state
|
||||||
|
|
||||||
|
|
|
@ -138,7 +138,8 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
target_unclipped_quantile=0.0,
|
target_unclipped_quantile=0.0,
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
clipped_count_stddev=0.0,
|
clipped_count_stddev=0.0,
|
||||||
expected_num_records=2.0)
|
expected_num_records=2.0,
|
||||||
|
geometric_update=False)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
@ -207,7 +208,8 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
target_unclipped_quantile=1.0,
|
target_unclipped_quantile=1.0,
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
clipped_count_stddev=0.0,
|
clipped_count_stddev=0.0,
|
||||||
expected_num_records=2.0)
|
expected_num_records=2.0,
|
||||||
|
geometric_update=False)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
@ -344,7 +346,8 @@ class QuantileAdaptiveClipSumQueryTest(
|
||||||
target_unclipped_quantile=0.0,
|
target_unclipped_quantile=0.0,
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
clipped_count_stddev=0.0,
|
clipped_count_stddev=0.0,
|
||||||
expected_num_records=2.0)
|
expected_num_records=2.0,
|
||||||
|
geometric_update=False)
|
||||||
|
|
||||||
query = privacy_ledger.QueryWithLedger(
|
query = privacy_ledger.QueryWithLedger(
|
||||||
query, population_size, selection_probability)
|
query, population_size, selection_probability)
|
||||||
|
|
158
tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py
Normal file
158
tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Implements DPQuery interface for quantile estimator.
|
||||||
|
|
||||||
|
From a starting estimate of the target quantile, the estimate is updated
|
||||||
|
dynamically, where the fraction of records below the estimate is estimated in a
|
||||||
|
differentially private manner. For details see Thakkar et al., "Differentially
|
||||||
|
Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871].
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
|
|
||||||
|
|
||||||
|
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
"""Defines iterative process to estimate a target quantile of a distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
_GlobalState = collections.namedtuple(
|
||||||
|
'_GlobalState', [
|
||||||
|
'current_estimate',
|
||||||
|
'target_quantile',
|
||||||
|
'learning_rate',
|
||||||
|
'below_estimate_state'])
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
_SampleParams = collections.namedtuple(
|
||||||
|
'_SampleParams', ['current_estimate', 'below_estimate_params'])
|
||||||
|
|
||||||
|
# No separate SampleState-- sample state is just below_estimate_query's
|
||||||
|
# SampleState.
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_estimate,
|
||||||
|
target_quantile,
|
||||||
|
learning_rate,
|
||||||
|
below_estimate_stddev,
|
||||||
|
expected_num_records,
|
||||||
|
geometric_update=False):
|
||||||
|
"""Initializes the QuantileEstimatorQuery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_estimate: The initial estimate of the quantile.
|
||||||
|
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
||||||
|
should be found for which approximately 80% of updates are
|
||||||
|
less than the estimate each round.
|
||||||
|
learning_rate: The learning rate. A rate of r means that the estimate
|
||||||
|
will change by a maximum of r at each step (for arithmetic updating) or
|
||||||
|
by a maximum factor of exp(r) (for geometric updating).
|
||||||
|
below_estimate_stddev: The stddev of the noise added to the count of
|
||||||
|
records currently below the estimate. Since the sensitivity of the count
|
||||||
|
query is 0.5, as a rule of thumb it should be about 0.5 for reasonable
|
||||||
|
privacy.
|
||||||
|
expected_num_records: The expected number of records per round.
|
||||||
|
geometric_update: If True, use geometric updating of estimate. Geometric
|
||||||
|
updating is preferred for non-negative records like vector norms that
|
||||||
|
could potentially be very large or very close to zero.
|
||||||
|
"""
|
||||||
|
self._initial_estimate = initial_estimate
|
||||||
|
self._target_quantile = target_quantile
|
||||||
|
self._learning_rate = learning_rate
|
||||||
|
|
||||||
|
# A DPQuery used to estimate the fraction of records that are less than the
|
||||||
|
# current quantile estimate. It accumulates an indicator 0/1 of whether each
|
||||||
|
# record is below the estimate, and normalizes by the expected number of
|
||||||
|
# records. In practice, we accumulate counts shifted by -0.5 so they are
|
||||||
|
# centered at zero. This makes the sensitivity of the below_estimate count
|
||||||
|
# query 0.5 instead of 1.0, since the maximum that a single record could
|
||||||
|
# affect the count is 0.5. Note that although the l2_norm_clip of the
|
||||||
|
# below_estimate query is 0.5, no clipping will ever actually occur
|
||||||
|
# because the value of each record is always +/-0.5.
|
||||||
|
self._below_estimate_query = gaussian_query.GaussianAverageQuery(
|
||||||
|
l2_norm_clip=0.5,
|
||||||
|
sum_stddev=below_estimate_stddev,
|
||||||
|
denominator=expected_num_records)
|
||||||
|
|
||||||
|
self._geometric_update = geometric_update
|
||||||
|
|
||||||
|
assert isinstance(self._below_estimate_query,
|
||||||
|
dp_query.SumAggregationDPQuery)
|
||||||
|
|
||||||
|
def set_ledger(self, ledger):
|
||||||
|
"""See base class."""
|
||||||
|
self._below_estimate_query.set_ledger(ledger)
|
||||||
|
|
||||||
|
def initial_global_state(self):
|
||||||
|
"""See base class."""
|
||||||
|
return self._GlobalState(
|
||||||
|
tf.cast(self._initial_estimate, tf.float32),
|
||||||
|
tf.cast(self._target_quantile, tf.float32),
|
||||||
|
tf.cast(self._learning_rate, tf.float32),
|
||||||
|
self._below_estimate_query.initial_global_state())
|
||||||
|
|
||||||
|
def derive_sample_params(self, global_state):
|
||||||
|
"""See base class."""
|
||||||
|
below_estimate_params = self._below_estimate_query.derive_sample_params(
|
||||||
|
global_state.below_estimate_state)
|
||||||
|
return self._SampleParams(global_state.current_estimate,
|
||||||
|
below_estimate_params)
|
||||||
|
|
||||||
|
def preprocess_record(self, params, record):
|
||||||
|
# We accumulate counts shifted by -0.5 so they are centered at zero.
|
||||||
|
# This makes the sensitivity of the count query 0.5 instead of 1.0.
|
||||||
|
below = tf.cast(record <= params.current_estimate, tf.float32) - 0.5
|
||||||
|
return self._below_estimate_query.preprocess_record(
|
||||||
|
params.below_estimate_params, below)
|
||||||
|
|
||||||
|
def get_noised_result(self, sample_state, global_state):
|
||||||
|
"""See base class."""
|
||||||
|
below_estimate_result, new_below_estimate_state = (
|
||||||
|
self._below_estimate_query.get_noised_result(
|
||||||
|
sample_state,
|
||||||
|
global_state.below_estimate_state))
|
||||||
|
|
||||||
|
# Unshift below_estimate percentile by 0.5. (See comment in initializer.)
|
||||||
|
below_estimate = below_estimate_result + 0.5
|
||||||
|
|
||||||
|
# Protect against out-of-range estimates.
|
||||||
|
below_estimate = tf.minimum(1.0, tf.maximum(0.0, below_estimate))
|
||||||
|
|
||||||
|
# Loss function is convex, with derivative in [-1, 1], and minimized when
|
||||||
|
# the true quantile matches the target.
|
||||||
|
loss_grad = below_estimate - global_state.target_quantile
|
||||||
|
|
||||||
|
update = global_state.learning_rate * loss_grad
|
||||||
|
|
||||||
|
if self._geometric_update:
|
||||||
|
new_estimate = global_state.current_estimate * tf.math.exp(-update)
|
||||||
|
else:
|
||||||
|
new_estimate = global_state.current_estimate - update
|
||||||
|
|
||||||
|
new_global_state = global_state._replace(
|
||||||
|
current_estimate=new_estimate,
|
||||||
|
below_estimate_state=new_below_estimate_state)
|
||||||
|
|
||||||
|
return new_estimate, new_global_state
|
|
@ -0,0 +1,219 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for QuantileEstimatorQuery."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
|
|
||||||
|
tf.enable_eager_execution()
|
||||||
|
|
||||||
|
|
||||||
|
class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_target_zero(self):
|
||||||
|
record1 = tf.constant(8.5)
|
||||||
|
record2 = tf.constant(7.25)
|
||||||
|
|
||||||
|
query = quantile_estimator_query.QuantileEstimatorQuery(
|
||||||
|
initial_estimate=10.0,
|
||||||
|
target_quantile=0.0,
|
||||||
|
learning_rate=1.0,
|
||||||
|
below_estimate_stddev=0.0,
|
||||||
|
expected_num_records=2.0,
|
||||||
|
geometric_update=False)
|
||||||
|
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
initial_estimate = global_state.current_estimate
|
||||||
|
self.assertAllClose(initial_estimate, 10.0)
|
||||||
|
|
||||||
|
# On the first two iterations, both records are below, so the estimate goes
|
||||||
|
# down by 1.0 (the learning rate). When the estimate reaches 8.0, only one
|
||||||
|
# record is below, so the estimate goes down by only 0.5. After two more
|
||||||
|
# iterations, neither record is below, and the estimate stays there (at 7.0).
|
||||||
|
|
||||||
|
expected_estimates = [9.0, 8.0, 7.5, 7.0, 7.0]
|
||||||
|
for expected_estimate in expected_estimates:
|
||||||
|
actual_estimate, global_state = test_utils.run_query(
|
||||||
|
query, [record1, record2], global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
|
def test_target_zero_geometric(self):
|
||||||
|
record1 = tf.constant(5.0)
|
||||||
|
record2 = tf.constant(2.5)
|
||||||
|
|
||||||
|
query = quantile_estimator_query.QuantileEstimatorQuery(
|
||||||
|
initial_estimate=16.0,
|
||||||
|
target_quantile=0.0,
|
||||||
|
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
||||||
|
below_estimate_stddev=0.0,
|
||||||
|
expected_num_records=2.0,
|
||||||
|
geometric_update=True)
|
||||||
|
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
initial_estimate = global_state.current_estimate
|
||||||
|
self.assertAllClose(initial_estimate, 16.0)
|
||||||
|
|
||||||
|
# For two iterations, both records are below, so the estimate is halved.
|
||||||
|
# Then only one record is below, so the estimate goes down by only sqrt(2.0)
|
||||||
|
# to 4 / sqrt(2.0). Still only one record is below, so it reduces to 2.0.
|
||||||
|
# Now no records are below, and the estimate stays there (at 2.0).
|
||||||
|
|
||||||
|
four_div_root_two = 4 / np.sqrt(2.0) # approx 2.828
|
||||||
|
|
||||||
|
expected_estimates = [8.0, 4.0, four_div_root_two, 2.0, 2.0]
|
||||||
|
for expected_estimate in expected_estimates:
|
||||||
|
actual_estimate, global_state = test_utils.run_query(
|
||||||
|
query, [record1, record2], global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
|
def test_target_one(self):
|
||||||
|
record1 = tf.constant(1.5)
|
||||||
|
record2 = tf.constant(2.75)
|
||||||
|
|
||||||
|
query = quantile_estimator_query.QuantileEstimatorQuery(
|
||||||
|
initial_estimate=0.0,
|
||||||
|
target_quantile=1.0,
|
||||||
|
learning_rate=1.0,
|
||||||
|
below_estimate_stddev=0.0,
|
||||||
|
expected_num_records=2.0,
|
||||||
|
geometric_update=False)
|
||||||
|
|
||||||
|
global_state = query.initial_global_state()
|
||||||
|
|
||||||
|
initial_estimate = global_state.current_estimate
|
||||||
|
self.assertAllClose(initial_estimate, 0.0)
|
||||||
|
|
||||||
|
# On the first two iterations, both are above, so the estimate goes up
|
||||||
|
# by 1.0 (the learning rate). When it reaches 2.0, only one record is
|
||||||
|
# above, so the estimate goes up by only 0.5. After two more iterations,
|
||||||
|
# both records are below, and the estimate stays there (at 3.0).
|
||||||
|
|
||||||
|
expected_estimates = [1.0, 2.0, 2.5, 3.0, 3.0]
|
||||||
|
for expected_estimate in expected_estimates:
|
||||||
|
actual_estimate, global_state = test_utils.run_query(
|
||||||
|
query, [record1, record2], global_state)
|
||||||
|
|
||||||
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
|
def test_target_one_geometric(self):
    """Checks geometric updates when the target quantile is 1.0."""
    record1 = tf.constant(1.5)
    record2 = tf.constant(3.0)

    query = quantile_estimator_query.QuantileEstimatorQuery(
        initial_estimate=0.5,
        target_quantile=1.0,
        learning_rate=np.log(2.0),  # Geometric steps in powers of 2.
        below_estimate_stddev=0.0,
        expected_num_records=2.0,
        geometric_update=True)

    global_state = query.initial_global_state()

    initial_estimate = global_state.current_estimate
    self.assertAllClose(initial_estimate, 0.5)

    # On the first two iterations, both are above, so the estimate is doubled.
    # When the estimate reaches 2.0, only one record is above, so the estimate
    # is multiplied by sqrt(2.0). Still only one is above so it increases to
    # 4.0. Now both records are below, and the estimate stays there (at 4.0).

    two_times_root_two = 2 * np.sqrt(2.0)  # approx 2.828

    expected_estimates = [1.0, 2.0, two_times_root_two, 4.0, 4.0]
    for expected_estimate in expected_estimates:
        actual_estimate, global_state = test_utils.run_query(
            query, [record1, record2], global_state)

        self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
    ('start_low_arithmetic', True, False),
    ('start_low_geometric', True, True),
    ('start_high_arithmetic', False, False),
    ('start_high_geometric', False, True))
def test_linspace(self, start_low, geometric):
    """Checks convergence to the median of evenly spaced records.

    Args:
      start_low: If True the initial estimate is 1.0, otherwise 10.0.
      geometric: Passed through as the query's `geometric_update` flag.
    """
    # 21 records equally spaced from 0 to 10 in 0.5 increments.
    # Test that we converge to the correct median value and bounce around it.
    num_records = 21
    records = [tf.constant(x) for x in np.linspace(
        0.0, 10.0, num=num_records, dtype=np.float32)]

    query = quantile_estimator_query.QuantileEstimatorQuery(
        initial_estimate=(1.0 if start_low else 10.0),
        target_quantile=0.5,
        learning_rate=1.0,
        below_estimate_stddev=0.0,
        expected_num_records=num_records,
        geometric_update=geometric)

    global_state = query.initial_global_state()

    for t in range(50):
        _, global_state = test_utils.run_query(query, records, global_state)

        actual_estimate = global_state.current_estimate

        # Only the last rounds are checked, after the estimate has had time
        # to travel from its starting point to the median.
        if t > 40:
            self.assertNear(actual_estimate, 5.0, 0.25)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
    ('start_low_arithmetic', True, False),
    ('start_low_geometric', True, True),
    ('start_high_arithmetic', False, False),
    ('start_high_geometric', False, True))
def test_all_equal(self, start_low, geometric):
    """Checks convergence when every record has the same value.

    Args:
      start_low: If True the initial estimate is 1.0, otherwise 10.0.
      geometric: Passed through as the query's `geometric_update` flag.
    """
    # 20 equal records. Test that we converge to that record and bounce around
    # it. Unlike the linspace test, the quantile-matching objective is very
    # sharp at the optimum so a decaying learning rate is necessary.
    num_records = 20
    records = [tf.constant(5.0)] * num_records

    learning_rate = tf.Variable(1.0)

    query = quantile_estimator_query.QuantileEstimatorQuery(
        initial_estimate=(1.0 if start_low else 10.0),
        target_quantile=0.5,
        learning_rate=learning_rate,
        below_estimate_stddev=0.0,
        expected_num_records=num_records,
        geometric_update=geometric)

    global_state = query.initial_global_state()

    for t in range(50):
        # Use the Variable's own assign method rather than the free function
        # `tf.assign`, which is only available in the TF1 / compat.v1 API.
        learning_rate.assign(1.0 / np.sqrt(t + 1))
        _, global_state = test_utils.run_query(query, records, global_state)

        actual_estimate = global_state.current_estimate

        if t > 40:
            self.assertNear(actual_estimate, 5.0, 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test suite only when this file is executed directly as a script.
if __name__ == '__main__':
    tf.test.main()
|
Loading…
Reference in a new issue