From cec011e2a797a32a9b102a69d2d799256de7ca98 Mon Sep 17 00:00:00 2001 From: Galen Andrew Date: Mon, 8 Jun 2020 10:06:12 -0700 Subject: [PATCH] Refactor quantile estimation logic from QuantileAdaptiveClipSumQuery so it can be used for other purposes. PiperOrigin-RevId: 315297665 --- tensorflow_privacy/__init__.py | 1 + tensorflow_privacy/privacy/dp_query/BUILD | 25 ++ .../privacy/dp_query/normalized_query.py | 16 +- .../quantile_adaptive_clip_sum_query.py | 138 ++++------- .../quantile_adaptive_clip_sum_query_test.py | 9 +- .../dp_query/quantile_estimator_query.py | 158 +++++++++++++ .../dp_query/quantile_estimator_query_test.py | 219 ++++++++++++++++++ 7 files changed, 458 insertions(+), 108 deletions(-) create mode 100644 tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py create mode 100644 tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py diff --git a/tensorflow_privacy/__init__.py b/tensorflow_privacy/__init__.py index 7c00cb0..648e98a 100644 --- a/tensorflow_privacy/__init__.py +++ b/tensorflow_privacy/__init__.py @@ -36,6 +36,7 @@ else: from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacySumQuery from tensorflow_privacy.privacy.dp_query.normalized_query import NormalizedQuery + from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery diff --git a/tensorflow_privacy/privacy/dp_query/BUILD b/tensorflow_privacy/privacy/dp_query/BUILD index 1f1e812..29cd78c 100644 --- a/tensorflow_privacy/privacy/dp_query/BUILD +++ b/tensorflow_privacy/privacy/dp_query/BUILD @@ -120,6 +120,7 @@ py_library( ":dp_query", ":gaussian_query", ":normalized_query", + ":quantile_estimator_query", "//third_party/py/tensorflow", ], ) @@ -139,6 +140,30 @@ py_test( ], ) +py_library( + name = "quantile_estimator_query", + srcs = ["quantile_estimator_query.py"], + deps = [ + ":dp_query", + ":gaussian_query", + "//third_party/py/tensorflow", + ], +) + +py_test( + name = "quantile_estimator_query_test", + srcs = ["quantile_estimator_query_test.py"], + python_version = "PY3", + deps = [ + ":quantile_estimator_query", + ":test_utils", + "//learning/brain/public:disable_tf2", # build_cleaner: keep; go/disable_tf2 + "//third_party/py/absl/testing:parameterized", + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ], +) + py_library( name = "test_utils", srcs = ["test_utils.py"], diff --git a/tensorflow_privacy/privacy/dp_query/normalized_query.py b/tensorflow_privacy/privacy/dp_query/normalized_query.py index f09a113..f5d8f42 100644 --- a/tensorflow_privacy/privacy/dp_query/normalized_query.py +++ b/tensorflow_privacy/privacy/dp_query/normalized_query.py @@ -26,7 +26,7 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.dp_query import dp_query -class NormalizedQuery(dp_query.DPQuery): +class NormalizedQuery(dp_query.SumAggregationDPQuery): """DPQuery for queries with a DPQuery numerator and fixed denominator.""" # pylint: disable=invalid-name @@ -37,7 +37,7 @@ class NormalizedQuery(dp_query.DPQuery): """Initializer for NormalizedQuery. Args: - numerator_query: A DPQuery for the numerator. + numerator_query: A SumAggregationDPQuery for the numerator. denominator: A value for the denominator. 
May be None if it will be supplied via the set_denominator function before get_noised_result is called. @@ -45,6 +45,8 @@ class NormalizedQuery(dp_query.DPQuery): self._numerator = numerator_query self._denominator = denominator + assert isinstance(self._numerator, dp_query.SumAggregationDPQuery) + def set_ledger(self, ledger): """See base class.""" self._numerator.set_ledger(ledger) @@ -70,12 +72,6 @@ class NormalizedQuery(dp_query.DPQuery): def preprocess_record(self, params, record): return self._numerator.preprocess_record(params, record) - def accumulate_preprocessed_record( - self, sample_state, preprocessed_record): - """See base class.""" - return self._numerator.accumulate_preprocessed_record( - sample_state, preprocessed_record) - def get_noised_result(self, sample_state, global_state): """See base class.""" noised_sum, new_sum_global_state = self._numerator.get_noised_result( @@ -85,7 +81,3 @@ class NormalizedQuery(dp_query.DPQuery): return (tf.nest.map_structure(normalize, noised_sum), self._GlobalState(new_sum_global_state, global_state.denominator)) - - def merge_sample_states(self, sample_state_1, sample_state_2): - """See base class.""" - return self._numerator.merge_sample_states(sample_state_1, sample_state_2) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py index 25ec88b..b8d4485 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py @@ -32,6 +32,7 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import normalized_query +from tensorflow_privacy.privacy.dp_query import quantile_estimator_query class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): @@ -45,18 +46,16 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): _GlobalState = collections.namedtuple( '_GlobalState', [ 'noise_multiplier', - 'target_unclipped_quantile', - 'learning_rate', 'sum_state', - 'clipped_fraction_state']) + 'quantile_estimator_state']) # pylint: disable=invalid-name _SampleState = collections.namedtuple( - '_SampleState', ['sum_state', 'clipped_fraction_state']) + '_SampleState', ['sum_state', 'quantile_estimator_state']) # pylint: disable=invalid-name _SampleParams = collections.namedtuple( - '_SampleParams', ['sum_params', 'clipped_fraction_params']) + '_SampleParams', ['sum_params', 'quantile_estimator_params']) def __init__( self, @@ -66,7 +65,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): learning_rate, clipped_count_stddev, expected_num_records, - geometric_update=False): + geometric_update=True): """Initializes the QuantileAdaptiveClipSumQuery. Args: @@ -86,126 +85,79 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): estimate the clipped count quantile. geometric_update: If True, use geometric updating of clip. """ - self._initial_l2_norm_clip = initial_l2_norm_clip self._noise_multiplier = noise_multiplier - self._target_unclipped_quantile = target_unclipped_quantile - self._learning_rate = learning_rate - # Initialize sum query's global state with None, to be set later. 
- self._sum_query = gaussian_query.GaussianSumQuery(None, None) + self._quantile_estimator_query = quantile_estimator_query.QuantileEstimatorQuery( + initial_l2_norm_clip, + target_unclipped_quantile, + learning_rate, + clipped_count_stddev, + expected_num_records, + geometric_update) - # self._clipped_fraction_query is a DPQuery used to estimate the fraction of - # records that are clipped. It accumulates an indicator 0/1 of whether each - # record is clipped, and normalizes by the expected number of records. In - # practice, we accumulate clipped counts shifted by -0.5 so they are - # centered at zero. This makes the sensitivity of the clipped count query - # 0.5 instead of 1.0, since the maximum that a single record could affect - # the count is 0.5. Note that although the l2_norm_clip of the clipped - # fraction query is 0.5, no clipping will ever actually occur because the - # value of each record is always +/-0.5. - self._clipped_fraction_query = gaussian_query.GaussianAverageQuery( - l2_norm_clip=0.5, - sum_stddev=clipped_count_stddev, - denominator=expected_num_records) + self._sum_query = gaussian_query.GaussianSumQuery( + initial_l2_norm_clip, + noise_multiplier * initial_l2_norm_clip) - self._geometric_update = geometric_update + assert isinstance(self._sum_query, dp_query.SumAggregationDPQuery) + assert isinstance(self._quantile_estimator_query, + dp_query.SumAggregationDPQuery) def set_ledger(self, ledger): """See base class.""" self._sum_query.set_ledger(ledger) - self._clipped_fraction_query.set_ledger(ledger) + self._quantile_estimator_query.set_ledger(ledger) def initial_global_state(self): """See base class.""" - initial_l2_norm_clip = tf.cast(self._initial_l2_norm_clip, tf.float32) - noise_multiplier = tf.cast(self._noise_multiplier, tf.float32) - target_unclipped_quantile = tf.cast(self._target_unclipped_quantile, - tf.float32) - learning_rate = tf.cast(self._learning_rate, tf.float32) - sum_stddev = initial_l2_norm_clip * noise_multiplier - - sum_query_global_state = self._sum_query.make_global_state( - l2_norm_clip=initial_l2_norm_clip, - stddev=sum_stddev) - return self._GlobalState( - noise_multiplier, - target_unclipped_quantile, - learning_rate, - sum_query_global_state, - self._clipped_fraction_query.initial_global_state()) + tf.cast(self._noise_multiplier, tf.float32), + self._sum_query.initial_global_state(), + self._quantile_estimator_query.initial_global_state()) def derive_sample_params(self, global_state): """See base class.""" - - # Assign values to variables that inner sum query uses. 
- sum_params = self._sum_query.derive_sample_params(global_state.sum_state) - clipped_fraction_params = self._clipped_fraction_query.derive_sample_params( - global_state.clipped_fraction_state) - return self._SampleParams(sum_params, clipped_fraction_params) + return self._SampleParams( + self._sum_query.derive_sample_params(global_state.sum_state), + self._quantile_estimator_query.derive_sample_params( + global_state.quantile_estimator_state)) def initial_sample_state(self, template): """See base class.""" - sum_state = self._sum_query.initial_sample_state(template) - clipped_fraction_state = self._clipped_fraction_query.initial_sample_state( - tf.constant(0.0)) - return self._SampleState(sum_state, clipped_fraction_state) + return self._SampleState( + self._sum_query.initial_sample_state(template), + self._quantile_estimator_query.initial_sample_state(tf.constant(0.0))) def preprocess_record(self, params, record): - preprocessed_sum_record, global_norm = ( + clipped_record, global_norm = ( self._sum_query.preprocess_record_impl(params.sum_params, record)) - # Note we are relying on the internals of GaussianSumQuery here. If we want - # to open this up to other kinds of inner queries we'd have to do this in a - # more general way. - l2_norm_clip = params.sum_params + was_unclipped = self._quantile_estimator_query.preprocess_record( + params.quantile_estimator_params, global_norm) - # We accumulate clipped counts shifted by 0.5 so they are centered at zero. - # This makes the sensitivity of the clipped count query 0.5 instead of 1.0. - was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5 - - return self._SampleState(preprocessed_sum_record, was_clipped) + return self._SampleState(clipped_record, was_unclipped) def get_noised_result(self, sample_state, global_state): """See base class.""" - gs = global_state - noised_vectors, sum_state = self._sum_query.get_noised_result( - sample_state.sum_state, gs.sum_state) - del sum_state # Unused. To be set explicitly later. + sample_state.sum_state, global_state.sum_state) + del sum_state # To be set explicitly later when we know the new clip. - clipped_fraction_result, new_clipped_fraction_state = ( - self._clipped_fraction_query.get_noised_result( - sample_state.clipped_fraction_state, - gs.clipped_fraction_state)) - - # Unshift clipped percentile by 0.5. (See comment in initializer.) - clipped_quantile = clipped_fraction_result + 0.5 - unclipped_quantile = 1.0 - clipped_quantile - - # Protect against out-of-range estimates. - unclipped_quantile = tf.minimum(1.0, tf.maximum(0.0, unclipped_quantile)) - - # Loss function is convex, with derivative in [-1, 1], and minimized when - # the true quantile matches the target. 
-    loss_grad = unclipped_quantile - global_state.target_unclipped_quantile
-
-    update = global_state.learning_rate * loss_grad
-
-    if self._geometric_update:
-      new_l2_norm_clip = gs.sum_state.l2_norm_clip * tf.math.exp(-update)
-    else:
-      new_l2_norm_clip = tf.math.maximum(0.0,
-                                         gs.sum_state.l2_norm_clip - update)
+    new_l2_norm_clip, new_quantile_estimator_state = (
+        self._quantile_estimator_query.get_noised_result(
+            sample_state.quantile_estimator_state,
+            global_state.quantile_estimator_state))
+    new_l2_norm_clip = tf.maximum(new_l2_norm_clip, 0.0)

     new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
-    new_sum_query_global_state = self._sum_query.make_global_state(
+    new_sum_query_state = self._sum_query.make_global_state(
         l2_norm_clip=new_l2_norm_clip,
         stddev=new_sum_stddev)

-    new_global_state = global_state._replace(
-        sum_state=new_sum_query_global_state,
-        clipped_fraction_state=new_clipped_fraction_state)
+    new_global_state = self._GlobalState(
+        global_state.noise_multiplier,
+        new_sum_query_state,
+        new_quantile_estimator_state)

     return noised_vectors, new_global_state
diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py
index 033be71..4b21d41 100644
--- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py
+++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py
@@ -138,7 +138,8 @@ class QuantileAdaptiveClipSumQueryTest(
         target_unclipped_quantile=0.0,
         learning_rate=1.0,
         clipped_count_stddev=0.0,
-        expected_num_records=2.0)
+        expected_num_records=2.0,
+        geometric_update=False)

     global_state = query.initial_global_state()

@@ -207,7 +208,8 @@ class QuantileAdaptiveClipSumQueryTest(
         target_unclipped_quantile=1.0,
         learning_rate=1.0,
         clipped_count_stddev=0.0,
-        expected_num_records=2.0)
+        expected_num_records=2.0,
+        geometric_update=False)

     global_state = query.initial_global_state()

@@ -344,7 +346,8 @@ class QuantileAdaptiveClipSumQueryTest(
         target_unclipped_quantile=0.0,
         learning_rate=1.0,
         clipped_count_stddev=0.0,
-        expected_num_records=2.0)
+        expected_num_records=2.0,
+        geometric_update=False)

     query = privacy_ledger.QueryWithLedger(
         query, population_size, selection_probability)
diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py
new file mode 100644
index 0000000..cdb93ae
--- /dev/null
+++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py
@@ -0,0 +1,158 @@
+# Copyright 2019, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implements DPQuery interface for a quantile estimator.
+
+Starting from an initial estimate, the estimate of the target quantile is
+updated each round using a differentially private measurement of the
+fraction of records that fall below the current estimate.
+For details see Thakkar et al., "Differentially Private Learning with
+Adaptive Clipping" [http://arxiv.org/abs/1905.03871].
+"""

+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import tensorflow.compat.v1 as tf
+
+from tensorflow_privacy.privacy.dp_query import dp_query
+from tensorflow_privacy.privacy.dp_query import gaussian_query
+
+
+class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
+  """Iteratively estimates a target quantile of a distribution."""
+
+  # pylint: disable=invalid-name
+  _GlobalState = collections.namedtuple(
+      '_GlobalState', [
+          'current_estimate',
+          'target_quantile',
+          'learning_rate',
+          'below_estimate_state'])
+
+  # pylint: disable=invalid-name
+  _SampleParams = collections.namedtuple(
+      '_SampleParams', ['current_estimate', 'below_estimate_params'])
+
+  # No separate SampleState -- the sample state is just the
+  # below_estimate_query's SampleState.
+
+  def __init__(
+      self,
+      initial_estimate,
+      target_quantile,
+      learning_rate,
+      below_estimate_stddev,
+      expected_num_records,
+      geometric_update=False):
+    """Initializes the QuantileEstimatorQuery.
+
+    Args:
+      initial_estimate: The initial estimate of the quantile.
+      target_quantile: The target quantile. For example, a value of 0.8 means
+        a value should be found below which approximately 80% of records fall
+        each round.
+      learning_rate: The learning rate. A rate of r means that the estimate
+        will change by a maximum of r at each step (for arithmetic updating)
+        or by a maximum factor of exp(r) (for geometric updating).
+      below_estimate_stddev: The stddev of the noise added to the count of
+        records currently below the estimate. Since the sensitivity of the
+        count query is 0.5, as a rule of thumb it should be about 0.5 for
+        reasonable privacy.
+      expected_num_records: The expected number of records per round.
+      geometric_update: If True, use geometric updating of the estimate.
+        Geometric updating is preferred for non-negative records like vector
+        norms that could potentially be very large or very close to zero.
+    """
+    self._initial_estimate = initial_estimate
+    self._target_quantile = target_quantile
+    self._learning_rate = learning_rate
+
+    # A DPQuery used to estimate the fraction of records that are less than
+    # the current quantile estimate. It accumulates an indicator 0/1 of
+    # whether each record is below the estimate, and normalizes by the
+    # expected number of records. In practice, we accumulate counts shifted
+    # by -0.5 so they are centered at zero. This makes the sensitivity of the
+    # below_estimate count query 0.5 instead of 1.0, since the maximum that a
+    # single record could affect the count is 0.5. Note that although the
+    # l2_norm_clip of the below_estimate query is 0.5, no clipping will ever
+    # actually occur because the value of each record is always +/-0.5.
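+    # For example, with expected_num_records=100 and 80 records below the
+    # estimate, the (noiseless) shifted sum is 80 * 0.5 + 20 * (-0.5) = 30,
+    # which normalizes to 30 / 100 = 0.3; unshifting by +0.5 in
+    # get_noised_result recovers the below-estimate fraction 0.8.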
+ self._below_estimate_query = gaussian_query.GaussianAverageQuery( + l2_norm_clip=0.5, + sum_stddev=below_estimate_stddev, + denominator=expected_num_records) + + self._geometric_update = geometric_update + + assert isinstance(self._below_estimate_query, + dp_query.SumAggregationDPQuery) + + def set_ledger(self, ledger): + """See base class.""" + self._below_estimate_query.set_ledger(ledger) + + def initial_global_state(self): + """See base class.""" + return self._GlobalState( + tf.cast(self._initial_estimate, tf.float32), + tf.cast(self._target_quantile, tf.float32), + tf.cast(self._learning_rate, tf.float32), + self._below_estimate_query.initial_global_state()) + + def derive_sample_params(self, global_state): + """See base class.""" + below_estimate_params = self._below_estimate_query.derive_sample_params( + global_state.below_estimate_state) + return self._SampleParams(global_state.current_estimate, + below_estimate_params) + + def preprocess_record(self, params, record): + # We accumulate counts shifted by 0.5 so they are centered at zero. + # This makes the sensitivity of the count query 0.5 instead of 1.0. + below = tf.cast(record <= params.current_estimate, tf.float32) - 0.5 + return self._below_estimate_query.preprocess_record( + params.below_estimate_params, below) + + def get_noised_result(self, sample_state, global_state): + """See base class.""" + below_estimate_result, new_below_estimate_state = ( + self._below_estimate_query.get_noised_result( + sample_state, + global_state.below_estimate_state)) + + # Unshift below_estimate percentile by 0.5. (See comment in initializer.) + below_estimate = below_estimate_result + 0.5 + + # Protect against out-of-range estimates. + below_estimate = tf.minimum(1.0, tf.maximum(0.0, below_estimate)) + + # Loss function is convex, with derivative in [-1, 1], and minimized when + # the true quantile matches the target. + loss_grad = below_estimate - global_state.target_quantile + + update = global_state.learning_rate * loss_grad + + if self._geometric_update: + new_estimate = global_state.current_estimate * tf.math.exp(-update) + else: + new_estimate = global_state.current_estimate - update + + new_global_state = global_state._replace( + current_estimate=new_estimate, + below_estimate_state=new_below_estimate_state) + + return new_estimate, new_global_state diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py new file mode 100644 index 0000000..5eb5def --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query_test.py @@ -0,0 +1,219 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for QuantileAdaptiveClipSumQuery.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import numpy as np +import tensorflow.compat.v1 as tf + +from tensorflow_privacy.privacy.dp_query import quantile_estimator_query +from tensorflow_privacy.privacy.dp_query import test_utils + +tf.enable_eager_execution() + + +class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_target_zero(self): + record1 = tf.constant(8.5) + record2 = tf.constant(7.25) + + query = quantile_estimator_query.QuantileEstimatorQuery( + initial_estimate=10.0, + target_quantile=0.0, + learning_rate=1.0, + below_estimate_stddev=0.0, + expected_num_records=2.0, + geometric_update=False) + + global_state = query.initial_global_state() + + initial_estimate = global_state.current_estimate + self.assertAllClose(initial_estimate, 10.0) + + # On the first two iterations, both records are below, so the estimate goes + # down by 1.0 (the learning rate). When the estimate reaches 8.0, only one + # record is below, so the estimate goes down by only 0.5. After two more + # iterations, both records are below, and the estimate stays there (at 7.0). + + expected_estimates = [9.0, 8.0, 7.5, 7.0, 7.0] + for expected_estimate in expected_estimates: + actual_estimate, global_state = test_utils.run_query( + query, [record1, record2], global_state) + + self.assertAllClose(actual_estimate.numpy(), expected_estimate) + + def test_target_zero_geometric(self): + record1 = tf.constant(5.0) + record2 = tf.constant(2.5) + + query = quantile_estimator_query.QuantileEstimatorQuery( + initial_estimate=16.0, + target_quantile=0.0, + learning_rate=np.log(2.0), # Geometric steps in powers of 2. + below_estimate_stddev=0.0, + expected_num_records=2.0, + geometric_update=True) + + global_state = query.initial_global_state() + + initial_estimate = global_state.current_estimate + self.assertAllClose(initial_estimate, 16.0) + + # For two iterations, both records are below, so the estimate is halved. + # Then only one record is below, so the estimate goes down by only sqrt(2.0) + # to 4 / sqrt(2.0). Still only one record is below, so it reduces to 2.0. + # Now no records are below, and the estimate norm stays there (at 2.0). + + four_div_root_two = 4 / np.sqrt(2.0) # approx 2.828 + + expected_estimates = [8.0, 4.0, four_div_root_two, 2.0, 2.0] + for expected_estimate in expected_estimates: + actual_estimate, global_state = test_utils.run_query( + query, [record1, record2], global_state) + + self.assertAllClose(actual_estimate.numpy(), expected_estimate) + + def test_target_one(self): + record1 = tf.constant(1.5) + record2 = tf.constant(2.75) + + query = quantile_estimator_query.QuantileEstimatorQuery( + initial_estimate=0.0, + target_quantile=1.0, + learning_rate=1.0, + below_estimate_stddev=0.0, + expected_num_records=2.0, + geometric_update=False) + + global_state = query.initial_global_state() + + initial_estimate = global_state.current_estimate + self.assertAllClose(initial_estimate, 0.0) + + # On the first two iterations, both are above, so the estimate goes up + # by 1.0 (the learning rate). When it reaches 2.0, only one record is + # above, so the estimate goes up by only 0.5. After two more iterations, + # both records are below, and the estimate stays there (at 3.0). 
+
+    expected_estimates = [1.0, 2.0, 2.5, 3.0, 3.0]
+    for expected_estimate in expected_estimates:
+      actual_estimate, global_state = test_utils.run_query(
+          query, [record1, record2], global_state)
+
+      self.assertAllClose(actual_estimate.numpy(), expected_estimate)
+
+  def test_target_one_geometric(self):
+    record1 = tf.constant(1.5)
+    record2 = tf.constant(3.0)
+
+    query = quantile_estimator_query.QuantileEstimatorQuery(
+        initial_estimate=0.5,
+        target_quantile=1.0,
+        learning_rate=np.log(2.0),  # Geometric steps in powers of 2.
+        below_estimate_stddev=0.0,
+        expected_num_records=2.0,
+        geometric_update=True)
+
+    global_state = query.initial_global_state()
+
+    initial_estimate = global_state.current_estimate
+    self.assertAllClose(initial_estimate, 0.5)
+
+    # On the first two iterations, both records are above, so the estimate
+    # is doubled. When the estimate reaches 2.0, only one record is above, so
+    # the estimate is multiplied by sqrt(2.0). Still only one is above, so it
+    # increases to 4.0. Now both records are below, and the estimate stays
+    # there (at 4.0).
+
+    two_times_root_two = 2 * np.sqrt(2.0)  # approx 2.828
+
+    expected_estimates = [1.0, 2.0, two_times_root_two, 4.0, 4.0]
+    for expected_estimate in expected_estimates:
+      actual_estimate, global_state = test_utils.run_query(
+          query, [record1, record2], global_state)
+
+      self.assertAllClose(actual_estimate.numpy(), expected_estimate)
+
+  @parameterized.named_parameters(
+      ('start_low_arithmetic', True, False),
+      ('start_low_geometric', True, True),
+      ('start_high_arithmetic', False, False),
+      ('start_high_geometric', False, True))
+  def test_linspace(self, start_low, geometric):
+    # 21 records equally spaced from 0 to 10 in 0.5 increments.
+    # Test that we converge to the correct median value and bounce around it.
+    num_records = 21
+    records = [tf.constant(x) for x in np.linspace(
+        0.0, 10.0, num=num_records, dtype=np.float32)]
+
+    query = quantile_estimator_query.QuantileEstimatorQuery(
+        initial_estimate=(1.0 if start_low else 10.0),
+        target_quantile=0.5,
+        learning_rate=1.0,
+        below_estimate_stddev=0.0,
+        expected_num_records=num_records,
+        geometric_update=geometric)
+
+    global_state = query.initial_global_state()
+
+    for t in range(50):
+      _, global_state = test_utils.run_query(query, records, global_state)
+
+      actual_estimate = global_state.current_estimate
+
+      if t > 40:
+        self.assertNear(actual_estimate, 5.0, 0.25)
+
+  @parameterized.named_parameters(
+      ('start_low_arithmetic', True, False),
+      ('start_low_geometric', True, True),
+      ('start_high_arithmetic', False, False),
+      ('start_high_geometric', False, True))
+  def test_all_equal(self, start_low, geometric):
+    # 20 equal records. Test that we converge to that record and bounce
+    # around it. Unlike the linspace test, the quantile-matching objective is
+    # very sharp at the optimum, so a decaying learning rate is necessary.
+ num_records = 20 + records = [tf.constant(5.0)] * num_records + + learning_rate = tf.Variable(1.0) + + query = quantile_estimator_query.QuantileEstimatorQuery( + initial_estimate=(1.0 if start_low else 10.0), + target_quantile=0.5, + learning_rate=learning_rate, + below_estimate_stddev=0.0, + expected_num_records=num_records, + geometric_update=geometric) + + global_state = query.initial_global_state() + + for t in range(50): + tf.assign(learning_rate, 1.0 / np.sqrt(t + 1)) + _, global_state = test_utils.run_query(query, records, global_state) + + actual_estimate = global_state.current_estimate + + if t > 40: + self.assertNear(actual_estimate, 5.0, 0.5) + + +if __name__ == '__main__': + tf.test.main()
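
Reviewer note: the snippet below is a minimal smoke test of the new public surface, not part of the commit. It assumes the modules land exactly as in the diff above (quantile_estimator_query plus the test_utils helper in the same directory), and it mirrors the test_linspace case with below_estimate_stddev=0.0 so the privacy noise is off and the deterministic update rule is easy to follow.

    import numpy as np
    import tensorflow.compat.v1 as tf

    from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
    from tensorflow_privacy.privacy.dp_query import test_utils

    tf.enable_eager_execution()

    # 21 records evenly spaced from 0 to 10; the true median is 5.0.
    records = [tf.constant(x)
               for x in np.linspace(0.0, 10.0, num=21, dtype=np.float32)]

    query = quantile_estimator_query.QuantileEstimatorQuery(
        initial_estimate=1.0,
        target_quantile=0.5,
        learning_rate=1.0,
        below_estimate_stddev=0.0,  # No noise: pure update-rule behavior.
        expected_num_records=21,
        geometric_update=True)

    global_state = query.initial_global_state()
    for _ in range(50):
        estimate, global_state = test_utils.run_query(
            query, records, global_state)

    print(estimate.numpy())  # Settles near the median, 5.0.

Note also that this commit flips the geometric_update default of QuantileAdaptiveClipSumQuery to True, which is why the updated tests above now pass geometric_update=False explicitly.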