Allow exact denominator for below estimate fraction used by quantile estimator.
Also: 1) Check that records for quantile estimator query are scalars. 2) Add tests of quantile estimator with noise. 3) Add privacy ledger to no-privacy queries. PiperOrigin-RevId: 320633937
This commit is contained in:
parent
d1e2cc1930
commit
2f51adac89
5 changed files with 177 additions and 43 deletions
|
@ -146,6 +146,7 @@ py_library(
|
||||||
deps = [
|
deps = [
|
||||||
":dp_query",
|
":dp_query",
|
||||||
":gaussian_query",
|
":gaussian_query",
|
||||||
|
":no_privacy_query",
|
||||||
"//third_party/py/tensorflow",
|
"//third_party/py/tensorflow",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2018, The TensorFlow Authors.
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
@ -28,8 +30,26 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
|
||||||
Accumulates vectors without clipping or adding noise.
|
Accumulates vectors without clipping or adding noise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._ledger = None
|
||||||
|
|
||||||
|
def set_ledger(self, ledger):
|
||||||
|
warnings.warn(
|
||||||
|
'Attempt to use NoPrivacySumQuery with privacy ledger. Privacy '
|
||||||
|
'guarantees will be vacuous.')
|
||||||
|
self._ledger = ledger
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
|
||||||
|
if self._ledger:
|
||||||
|
dependencies = [
|
||||||
|
self._ledger.record_sum_query(float('inf'), 0.0)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
dependencies = []
|
||||||
|
|
||||||
|
with tf.control_dependencies(dependencies):
|
||||||
return sample_state, global_state
|
return sample_state, global_state
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,6 +59,15 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||||
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
Accumulates vectors and normalizes by the total number of accumulated vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._ledger = None
|
||||||
|
|
||||||
|
def set_ledger(self, ledger):
|
||||||
|
warnings.warn(
|
||||||
|
'Attempt to use NoPrivacyAverageQuery with privacy ledger. Privacy '
|
||||||
|
'guarantees will be vacuous.')
|
||||||
|
self._ledger = ledger
|
||||||
|
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
|
return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
|
||||||
|
@ -59,5 +88,13 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
sum_state, denominator = sample_state
|
sum_state, denominator = sample_state
|
||||||
|
|
||||||
|
if self._ledger:
|
||||||
|
dependencies = [
|
||||||
|
self._ledger.record_sum_query(float('inf'), 0.0)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
dependencies = []
|
||||||
|
|
||||||
|
with tf.control_dependencies(dependencies):
|
||||||
return (tf.nest.map_structure(lambda t: t / denominator,
|
return (tf.nest.map_structure(lambda t: t / denominator,
|
||||||
sum_state), global_state)
|
sum_state), global_state)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2019, The TensorFlow Authors.
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -126,7 +126,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._SampleState(
|
return self._SampleState(
|
||||||
self._sum_query.initial_sample_state(template),
|
self._sum_query.initial_sample_state(template),
|
||||||
self._quantile_estimator_query.initial_sample_state(tf.constant(0.0)))
|
self._quantile_estimator_query.initial_sample_state())
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
clipped_record, global_norm = (
|
clipped_record, global_norm = (
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2019, The TensorFlow Authors.
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -30,10 +30,11 @@ import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import no_privacy_query
|
||||||
|
|
||||||
|
|
||||||
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Defines iterative process to estimate a target quantile of a distribution.
|
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
@ -82,6 +83,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
self._target_quantile = target_quantile
|
self._target_quantile = target_quantile
|
||||||
self._learning_rate = learning_rate
|
self._learning_rate = learning_rate
|
||||||
|
|
||||||
|
self._below_estimate_query = self._construct_below_estimate_query(
|
||||||
|
below_estimate_stddev, expected_num_records)
|
||||||
|
assert isinstance(self._below_estimate_query,
|
||||||
|
dp_query.SumAggregationDPQuery)
|
||||||
|
|
||||||
|
self._geometric_update = geometric_update
|
||||||
|
|
||||||
|
def _construct_below_estimate_query(
|
||||||
|
self, below_estimate_stddev, expected_num_records):
|
||||||
# A DPQuery used to estimate the fraction of records that are less than the
|
# A DPQuery used to estimate the fraction of records that are less than the
|
||||||
# current quantile estimate. It accumulates an indicator 0/1 of whether each
|
# current quantile estimate. It accumulates an indicator 0/1 of whether each
|
||||||
# record is below the estimate, and normalizes by the expected number of
|
# record is below the estimate, and normalizes by the expected number of
|
||||||
|
@ -91,16 +101,11 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
# affect the count is 0.5. Note that although the l2_norm_clip of the
|
# affect the count is 0.5. Note that although the l2_norm_clip of the
|
||||||
# below_estimate query is 0.5, no clipping will ever actually occur
|
# below_estimate query is 0.5, no clipping will ever actually occur
|
||||||
# because the value of each record is always +/-0.5.
|
# because the value of each record is always +/-0.5.
|
||||||
self._below_estimate_query = gaussian_query.GaussianAverageQuery(
|
return gaussian_query.GaussianAverageQuery(
|
||||||
l2_norm_clip=0.5,
|
l2_norm_clip=0.5,
|
||||||
sum_stddev=below_estimate_stddev,
|
sum_stddev=below_estimate_stddev,
|
||||||
denominator=expected_num_records)
|
denominator=expected_num_records)
|
||||||
|
|
||||||
self._geometric_update = geometric_update
|
|
||||||
|
|
||||||
assert isinstance(self._below_estimate_query,
|
|
||||||
dp_query.SumAggregationDPQuery)
|
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
self._below_estimate_query.set_ledger(ledger)
|
self._below_estimate_query.set_ledger(ledger)
|
||||||
|
@ -120,7 +125,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
return self._SampleParams(global_state.current_estimate,
|
return self._SampleParams(global_state.current_estimate,
|
||||||
below_estimate_params)
|
below_estimate_params)
|
||||||
|
|
||||||
|
def initial_sample_state(self, template=None):
|
||||||
|
# Template is ignored because records are required to be scalars.
|
||||||
|
del template
|
||||||
|
|
||||||
|
return self._below_estimate_query.initial_sample_state(0.0)
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
|
tf.debugging.assert_scalar(record)
|
||||||
|
|
||||||
# We accumulate counts shifted by 0.5 so they are centered at zero.
|
# We accumulate counts shifted by 0.5 so they are centered at zero.
|
||||||
# This makes the sensitivity of the count query 0.5 instead of 1.0.
|
# This makes the sensitivity of the count query 0.5 instead of 1.0.
|
||||||
below = tf.cast(record <= params.current_estimate, tf.float32) - 0.5
|
below = tf.cast(record <= params.current_estimate, tf.float32) - 0.5
|
||||||
|
@ -156,3 +169,42 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
below_estimate_state=new_below_estimate_state)
|
below_estimate_state=new_below_estimate_state)
|
||||||
|
|
||||||
return new_estimate, new_global_state
|
return new_estimate, new_global_state
|
||||||
|
|
||||||
|
|
||||||
|
class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
|
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||||
|
|
||||||
|
Unlike the base class, this uses a NoPrivacyQuery to estimate the fraction
|
||||||
|
below estimate with an exact denominator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_estimate,
|
||||||
|
target_quantile,
|
||||||
|
learning_rate,
|
||||||
|
geometric_update=False):
|
||||||
|
"""Initializes the NoPrivacyQuantileEstimatorQuery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_estimate: The initial estimate of the quantile.
|
||||||
|
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
||||||
|
should be found for which approximately 80% of updates are
|
||||||
|
less than the estimate each round.
|
||||||
|
learning_rate: The learning rate. A rate of r means that the estimate
|
||||||
|
will change by a maximum of r at each step (for arithmetic updating) or
|
||||||
|
by a maximum factor of exp(r) (for geometric updating).
|
||||||
|
geometric_update: If True, use geometric updating of estimate. Geometric
|
||||||
|
updating is preferred for non-negative records like vector norms that
|
||||||
|
could potentially be very large or very close to zero.
|
||||||
|
"""
|
||||||
|
super(NoPrivacyQuantileEstimatorQuery, self).__init__(
|
||||||
|
initial_estimate, target_quantile, learning_rate,
|
||||||
|
below_estimate_stddev=None, expected_num_records=None,
|
||||||
|
geometric_update=geometric_update)
|
||||||
|
|
||||||
|
def _construct_below_estimate_query(
|
||||||
|
self, below_estimate_stddev, expected_num_records):
|
||||||
|
del below_estimate_stddev
|
||||||
|
del expected_num_records
|
||||||
|
return no_privacy_query.NoPrivacyAverageQuery()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2019, The TensorFlow Authors.
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -29,18 +29,42 @@ from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_quantile_estimator_query(
|
||||||
|
initial_estimate,
|
||||||
|
target_quantile,
|
||||||
|
learning_rate,
|
||||||
|
below_estimate_stddev,
|
||||||
|
expected_num_records,
|
||||||
|
geometric_update):
|
||||||
|
if expected_num_records is not None:
|
||||||
|
return quantile_estimator_query.QuantileEstimatorQuery(
|
||||||
|
initial_estimate,
|
||||||
|
target_quantile,
|
||||||
|
learning_rate,
|
||||||
|
below_estimate_stddev,
|
||||||
|
expected_num_records,
|
||||||
|
geometric_update)
|
||||||
|
else:
|
||||||
|
return quantile_estimator_query.NoPrivacyQuantileEstimatorQuery(
|
||||||
|
initial_estimate,
|
||||||
|
target_quantile,
|
||||||
|
learning_rate,
|
||||||
|
geometric_update)
|
||||||
|
|
||||||
|
|
||||||
class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_target_zero(self):
|
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
||||||
|
def test_target_zero(self, exact):
|
||||||
record1 = tf.constant(8.5)
|
record1 = tf.constant(8.5)
|
||||||
record2 = tf.constant(7.25)
|
record2 = tf.constant(7.25)
|
||||||
|
|
||||||
query = quantile_estimator_query.QuantileEstimatorQuery(
|
query = _make_quantile_estimator_query(
|
||||||
initial_estimate=10.0,
|
initial_estimate=10.0,
|
||||||
target_quantile=0.0,
|
target_quantile=0.0,
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=2.0,
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=False)
|
geometric_update=False)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
@ -60,16 +84,17 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
def test_target_zero_geometric(self):
|
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
||||||
|
def test_target_zero_geometric(self, exact):
|
||||||
record1 = tf.constant(5.0)
|
record1 = tf.constant(5.0)
|
||||||
record2 = tf.constant(2.5)
|
record2 = tf.constant(2.5)
|
||||||
|
|
||||||
query = quantile_estimator_query.QuantileEstimatorQuery(
|
query = _make_quantile_estimator_query(
|
||||||
initial_estimate=16.0,
|
initial_estimate=16.0,
|
||||||
target_quantile=0.0,
|
target_quantile=0.0,
|
||||||
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=2.0,
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=True)
|
geometric_update=True)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
@ -91,16 +116,17 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
def test_target_one(self):
|
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
||||||
|
def test_target_one(self, exact):
|
||||||
record1 = tf.constant(1.5)
|
record1 = tf.constant(1.5)
|
||||||
record2 = tf.constant(2.75)
|
record2 = tf.constant(2.75)
|
||||||
|
|
||||||
query = quantile_estimator_query.QuantileEstimatorQuery(
|
query = _make_quantile_estimator_query(
|
||||||
initial_estimate=0.0,
|
initial_estimate=0.0,
|
||||||
target_quantile=1.0,
|
target_quantile=1.0,
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=2.0,
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=False)
|
geometric_update=False)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
@ -120,16 +146,17 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
def test_target_one_geometric(self):
|
@parameterized.named_parameters(('exact', True), ('fixed', False))
|
||||||
|
def test_target_one_geometric(self, exact):
|
||||||
record1 = tf.constant(1.5)
|
record1 = tf.constant(1.5)
|
||||||
record2 = tf.constant(3.0)
|
record2 = tf.constant(3.0)
|
||||||
|
|
||||||
query = quantile_estimator_query.QuantileEstimatorQuery(
|
query = _make_quantile_estimator_query(
|
||||||
initial_estimate=0.5,
|
initial_estimate=0.5,
|
||||||
target_quantile=1.0,
|
target_quantile=1.0,
|
||||||
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=0.0,
|
||||||
expected_num_records=2.0,
|
expected_num_records=(None if exact else 2.0),
|
||||||
geometric_update=True)
|
geometric_update=True)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
@ -152,23 +179,27 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('start_low_arithmetic', True, False),
|
('start_low_geometric_exact', True, True, True),
|
||||||
('start_low_geometric', True, True),
|
('start_low_arithmetic_exact', True, True, False),
|
||||||
('start_high_arithmetic', False, False),
|
('start_high_geometric_exact', True, False, True),
|
||||||
('start_high_geometric', False, True))
|
('start_high_arithmetic_exact', True, False, False),
|
||||||
def test_linspace(self, start_low, geometric):
|
('start_low_geometric_noised', False, True, True),
|
||||||
|
('start_low_arithmetic_noised', False, True, False),
|
||||||
|
('start_high_geometric_noised', False, False, True),
|
||||||
|
('start_high_arithmetic_noised', False, False, False))
|
||||||
|
def test_linspace(self, exact, start_low, geometric):
|
||||||
# 100 records equally spaced from 0 to 10 in 0.1 increments.
|
# 100 records equally spaced from 0 to 10 in 0.1 increments.
|
||||||
# Test that we converge to the correct median value and bounce around it.
|
# Test that we converge to the correct median value and bounce around it.
|
||||||
num_records = 21
|
num_records = 21
|
||||||
records = [tf.constant(x) for x in np.linspace(
|
records = [tf.constant(x) for x in np.linspace(
|
||||||
0.0, 10.0, num=num_records, dtype=np.float32)]
|
0.0, 10.0, num=num_records, dtype=np.float32)]
|
||||||
|
|
||||||
query = quantile_estimator_query.QuantileEstimatorQuery(
|
query = _make_quantile_estimator_query(
|
||||||
initial_estimate=(1.0 if start_low else 10.0),
|
initial_estimate=(1.0 if start_low else 10.0),
|
||||||
target_quantile=0.5,
|
target_quantile=0.5,
|
||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=(0.0 if exact else 1e-2),
|
||||||
expected_num_records=num_records,
|
expected_num_records=(None if exact else num_records),
|
||||||
geometric_update=geometric)
|
geometric_update=geometric)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
@ -182,11 +213,15 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertNear(actual_estimate, 5.0, 0.25)
|
self.assertNear(actual_estimate, 5.0, 0.25)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('start_low_arithmetic', True, False),
|
('start_low_geometric_exact', True, True, True),
|
||||||
('start_low_geometric', True, True),
|
('start_low_arithmetic_exact', True, True, False),
|
||||||
('start_high_arithmetic', False, False),
|
('start_high_geometric_exact', True, False, True),
|
||||||
('start_high_geometric', False, True))
|
('start_high_arithmetic_exact', True, False, False),
|
||||||
def test_all_equal(self, start_low, geometric):
|
('start_low_geometric_noised', False, True, True),
|
||||||
|
('start_low_arithmetic_noised', False, True, False),
|
||||||
|
('start_high_geometric_noised', False, False, True),
|
||||||
|
('start_high_arithmetic_noised', False, False, False))
|
||||||
|
def test_all_equal(self, exact, start_low, geometric):
|
||||||
# 20 equal records. Test that we converge to that record and bounce around
|
# 20 equal records. Test that we converge to that record and bounce around
|
||||||
# it. Unlike the linspace test, the quantile-matching objective is very
|
# it. Unlike the linspace test, the quantile-matching objective is very
|
||||||
# sharp at the optimum so a decaying learning rate is necessary.
|
# sharp at the optimum so a decaying learning rate is necessary.
|
||||||
|
@ -195,12 +230,12 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
learning_rate = tf.Variable(1.0)
|
learning_rate = tf.Variable(1.0)
|
||||||
|
|
||||||
query = quantile_estimator_query.QuantileEstimatorQuery(
|
query = _make_quantile_estimator_query(
|
||||||
initial_estimate=(1.0 if start_low else 10.0),
|
initial_estimate=(1.0 if start_low else 10.0),
|
||||||
target_quantile=0.5,
|
target_quantile=0.5,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
below_estimate_stddev=0.0,
|
below_estimate_stddev=(0.0 if exact else 1e-2),
|
||||||
expected_num_records=num_records,
|
expected_num_records=(None if exact else num_records),
|
||||||
geometric_update=geometric)
|
geometric_update=geometric)
|
||||||
|
|
||||||
global_state = query.initial_global_state()
|
global_state = query.initial_global_state()
|
||||||
|
@ -214,6 +249,15 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
if t > 40:
|
if t > 40:
|
||||||
self.assertNear(actual_estimate, 5.0, 0.5)
|
self.assertNear(actual_estimate, 5.0, 0.5)
|
||||||
|
|
||||||
|
def test_raises_with_non_scalar_record(self):
|
||||||
|
query = quantile_estimator_query.NoPrivacyQuantileEstimatorQuery(
|
||||||
|
initial_estimate=1.0,
|
||||||
|
target_quantile=0.5,
|
||||||
|
learning_rate=1.0)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, 'scalar'):
|
||||||
|
query.accumulate_record(None, None, [1.0, 2.0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue