forked from 626_privacy/tensorflow_privacy
Remove GaussianAverageQuery. Users can simply wrap GaussianSumQuery with a NormalizedQuery.
PiperOrigin-RevId: 374784618
This commit is contained in:
parent
1de7e4dde4
commit
e5848656ed
7 changed files with 96 additions and 189 deletions
|
@ -39,7 +39,6 @@ else:
|
|||
# DPQuery classes
|
||||
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
|
||||
from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery
|
||||
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianAverageQuery
|
||||
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
|
||||
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery
|
||||
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for PrivacyLedger."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
@ -57,10 +56,9 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
|||
population_size = tf.Variable(0)
|
||||
selection_probability = tf.Variable(1.0)
|
||||
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query = privacy_ledger.QueryWithLedger(
|
||||
query, population_size, selection_probability)
|
||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||
query = privacy_ledger.QueryWithLedger(query, population_size,
|
||||
selection_probability)
|
||||
|
||||
# First sample.
|
||||
tf.assign(population_size, 10)
|
||||
|
@ -93,14 +91,12 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
|||
population_size = tf.Variable(0)
|
||||
selection_probability = tf.Variable(1.0)
|
||||
|
||||
query1 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=4.0, sum_stddev=2.0, denominator=5.0)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=5.0, sum_stddev=1.0, denominator=5.0)
|
||||
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=2.0)
|
||||
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=1.0)
|
||||
|
||||
query = nested_query.NestedQuery([query1, query2])
|
||||
query = privacy_ledger.QueryWithLedger(
|
||||
query, population_size, selection_probability)
|
||||
query = privacy_ledger.QueryWithLedger(query, population_size,
|
||||
selection_probability)
|
||||
|
||||
record1 = [1.0, [12.0, 9.0]]
|
||||
record2 = [5.0, [1.0, 2.0]]
|
||||
|
|
|
@ -11,9 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Implements DPQuery interface for Gaussian average queries.
|
||||
"""
|
||||
"""Implements DPQuery interface for Gaussian sum queries."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -25,7 +23,6 @@ import distutils
|
|||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||
|
||||
|
||||
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||
|
@ -35,8 +32,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple(
|
||||
'_GlobalState', ['l2_norm_clip', 'stddev'])
|
||||
_GlobalState = collections.namedtuple('_GlobalState',
|
||||
['l2_norm_clip', 'stddev'])
|
||||
|
||||
def __init__(self, l2_norm_clip, stddev):
|
||||
"""Initializes the GaussianSumQuery.
|
||||
|
@ -55,8 +52,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
def make_global_state(self, l2_norm_clip, stddev):
|
||||
"""Creates a global state from the given parameters."""
|
||||
return self._GlobalState(tf.cast(l2_norm_clip, tf.float32),
|
||||
tf.cast(stddev, tf.float32))
|
||||
return self._GlobalState(
|
||||
tf.cast(l2_norm_clip, tf.float32), tf.cast(stddev, tf.float32))
|
||||
|
||||
def initial_global_state(self):
|
||||
return self.make_global_state(self._l2_norm_clip, self._stddev)
|
||||
|
@ -94,48 +91,17 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return v + tf.random.normal(
|
||||
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
|
||||
else:
|
||||
random_normal = tf.random_normal_initializer(
|
||||
stddev=global_state.stddev)
|
||||
random_normal = tf.random_normal_initializer(stddev=global_state.stddev)
|
||||
|
||||
def add_noise(v):
|
||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||
|
||||
if self._ledger:
|
||||
dependencies = [
|
||||
self._ledger.record_sum_query(
|
||||
global_state.l2_norm_clip, global_state.stddev)
|
||||
self._ledger.record_sum_query(global_state.l2_norm_clip,
|
||||
global_state.stddev)
|
||||
]
|
||||
else:
|
||||
dependencies = []
|
||||
with tf.control_dependencies(dependencies):
|
||||
return tf.nest.map_structure(add_noise, sample_state), global_state
|
||||
|
||||
|
||||
class GaussianAverageQuery(normalized_query.NormalizedQuery):
|
||||
"""Implements DPQuery interface for Gaussian average queries.
|
||||
|
||||
Accumulates clipped vectors, adds Gaussian noise, and normalizes.
|
||||
|
||||
Note that we use "fixed-denominator" estimation: the denominator should be
|
||||
specified as the expected number of records per sample. Accumulating the
|
||||
denominator separately would also be possible but would be produce a higher
|
||||
variance estimator.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
l2_norm_clip,
|
||||
sum_stddev,
|
||||
denominator):
|
||||
"""Initializes the GaussianAverageQuery.
|
||||
|
||||
Args:
|
||||
l2_norm_clip: The clipping norm to apply to the global norm of each
|
||||
record.
|
||||
sum_stddev: The stddev of the noise added to the sum (before
|
||||
normalization).
|
||||
denominator: The normalization constant (applied after noise is added to
|
||||
the sum).
|
||||
"""
|
||||
super(GaussianAverageQuery, self).__init__(
|
||||
numerator_query=GaussianSumQuery(l2_norm_clip, sum_stddev),
|
||||
denominator=denominator)
|
||||
|
|
|
@ -11,8 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for GaussianAverageQuery."""
|
||||
"""Tests for GaussianSumQuery."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -34,8 +33,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = tf.constant([2.0, 0.0])
|
||||
record2 = tf.constant([-1.0, 1.0])
|
||||
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
|
@ -46,8 +44,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=5.0, stddev=0.0)
|
||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
|
@ -80,8 +77,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record1, record2 = 2.71828, 3.14159
|
||||
stddev = 1.0
|
||||
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=5.0, stddev=stddev)
|
||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=stddev)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
noised_sums = []
|
||||
|
@ -108,8 +104,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
sample_state_2 = get_sample_state(records2)
|
||||
|
||||
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
|
||||
sample_state_1,
|
||||
sample_state_2)
|
||||
sample_state_1, sample_state_2)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = sess.run(merged)
|
||||
|
@ -117,36 +112,6 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
expected = [3.0, 10.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_gaussian_average_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
|
||||
record2 = tf.constant([-1.0, 2.0]) # Not clipped.
|
||||
|
||||
query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected_average = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected_average)
|
||||
|
||||
def test_gaussian_average_with_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1, record2 = 2.71828, 3.14159
|
||||
sum_stddev = 1.0
|
||||
denominator = 2.0
|
||||
|
||||
query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
noised_averages = []
|
||||
for _ in range(1000):
|
||||
noised_averages.append(sess.run(query_result))
|
||||
|
||||
result_stddev = np.std(noised_averages)
|
||||
avg_stddev = sum_stddev / denominator
|
||||
self.assertNear(result_stddev, avg_stddev, 0.1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('type_mismatch', [1.0], (1.0,), TypeError),
|
||||
('too_few_on_left', [1.0], [1.0, 1.0], ValueError),
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for NestedQuery."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
@ -27,9 +26,9 @@ import tensorflow.compat.v1 as tf
|
|||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import nested_query
|
||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||
|
||||
|
||||
_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
||||
|
||||
|
||||
|
@ -37,10 +36,8 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def test_nested_gaussian_sum_no_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
query1 = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query2 = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||
|
||||
query = nested_query.NestedSumQuery([query1, query2])
|
||||
|
||||
|
@ -52,29 +49,14 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
expected = [5.0, [5.0, 5.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_nested_gaussian_average_no_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
query1 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
||||
|
||||
query = nested_query.NestedSumQuery([query1, query2])
|
||||
|
||||
record1 = [1.0, [2.0, 3.0]]
|
||||
record2 = [4.0, [3.0, 2.0]]
|
||||
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, [1.0, 1.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_nested_gaussian_average_with_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
query1 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=4.0, sum_stddev=0.0, denominator=5.0)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
|
||||
query1 = normalized_query.NormalizedQuery(
|
||||
gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=0.0),
|
||||
denominator=5.0)
|
||||
query2 = normalized_query.NormalizedQuery(
|
||||
gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0),
|
||||
denominator=5.0)
|
||||
|
||||
query = nested_query.NestedSumQuery([query1, query2])
|
||||
|
||||
|
@ -88,15 +70,17 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def test_complex_nested_query(self):
|
||||
with self.cached_session() as sess:
|
||||
query_ab = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=1.0, stddev=0.0)
|
||||
query_c = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=2.0)
|
||||
query_d = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query_ab = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)
|
||||
query_c = normalized_query.NormalizedQuery(
|
||||
gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0),
|
||||
denominator=2.0)
|
||||
query_d = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||
|
||||
query = nested_query.NestedSumQuery(
|
||||
[query_ab, {'c': query_c, 'd': [query_d]}])
|
||||
[query_ab, {
|
||||
'c': query_c,
|
||||
'd': [query_d]
|
||||
}])
|
||||
|
||||
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
||||
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
|
||||
|
@ -108,13 +92,13 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def test_nested_query_with_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
sum_stddev = 2.71828
|
||||
stddev = 2.71828
|
||||
denominator = 3.14159
|
||||
|
||||
query1 = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=1.5, stddev=sum_stddev)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
|
||||
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=1.5, stddev=stddev)
|
||||
query2 = normalized_query.NormalizedQuery(
|
||||
gaussian_query.GaussianSumQuery(l2_norm_clip=0.5, stddev=stddev),
|
||||
denominator=denominator)
|
||||
query = nested_query.NestedSumQuery((query1, query2))
|
||||
|
||||
record1 = (3.0, [2.0, 1.5])
|
||||
|
@ -127,20 +111,20 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
noised_averages.append(tf.nest.flatten(sess.run(query_result)))
|
||||
|
||||
result_stddev = np.std(noised_averages, 0)
|
||||
avg_stddev = sum_stddev / denominator
|
||||
expected_stddev = [sum_stddev, avg_stddev, avg_stddev]
|
||||
avg_stddev = stddev / denominator
|
||||
expected_stddev = [stddev, avg_stddev, avg_stddev]
|
||||
self.assertArrayNear(result_stddev, expected_stddev, 0.1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('type_mismatch', [_basic_query], (1.0,), TypeError),
|
||||
('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError),
|
||||
('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError))
|
||||
def test_record_incompatible_with_query(
|
||||
self, queries, record, error_type):
|
||||
def test_record_incompatible_with_query(self, queries, record, error_type):
|
||||
with self.assertRaises(error_type):
|
||||
test_utils.run_query(nested_query.NestedSumQuery(queries), [record])
|
||||
|
||||
def test_raises_with_non_sum(self):
|
||||
|
||||
class NonSumDPQuery(dp_query.DPQuery):
|
||||
pass
|
||||
|
||||
|
@ -154,6 +138,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
nested_query.NestedSumQuery(non_sum_query)
|
||||
|
||||
def test_metrics(self):
|
||||
|
||||
class QueryWithMetric(dp_query.SumAggregationDPQuery):
|
||||
|
||||
def __init__(self, metric_val):
|
||||
|
@ -171,8 +156,10 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
metric_val = nested_a.derive_metrics(global_state)
|
||||
self.assertEqual(metric_val['metric'], 1)
|
||||
|
||||
nested_b = nested_query.NestedSumQuery(
|
||||
{'key1': query1, 'key2': [query2, query3]})
|
||||
nested_b = nested_query.NestedSumQuery({
|
||||
'key1': query1,
|
||||
'key2': [query2, query3]
|
||||
})
|
||||
global_state = nested_b.initial_global_state()
|
||||
metric_val = nested_b.derive_metrics(global_state)
|
||||
self.assertEqual(metric_val['key1/metric'], 1)
|
||||
|
|
|
@ -11,8 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for GaussianAverageQuery."""
|
||||
"""Tests for NormalizedQuery."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -32,8 +31,7 @@ class NormalizedQueryTest(tf.test.TestCase):
|
|||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||
|
||||
sum_query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=5.0, stddev=0.0)
|
||||
sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
|
||||
query = normalized_query.NormalizedQuery(
|
||||
numerator_query=sum_query, denominator=2.0)
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Implements DPQuery interface for quantile estimator.
|
||||
|
||||
From a starting estimate of the target quantile, the estimate is updated
|
||||
|
@ -27,23 +26,20 @@ from __future__ import print_function
|
|||
import collections
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import no_privacy_query
|
||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||
|
||||
|
||||
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||
"""
|
||||
"""Iterative process to estimate target quantile of a univariate distribution."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple(
|
||||
'_GlobalState', [
|
||||
'current_estimate',
|
||||
'target_quantile',
|
||||
'learning_rate',
|
||||
'below_estimate_state'])
|
||||
_GlobalState = collections.namedtuple('_GlobalState', [
|
||||
'current_estimate', 'target_quantile', 'learning_rate',
|
||||
'below_estimate_state'
|
||||
])
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_SampleParams = collections.namedtuple(
|
||||
|
@ -52,8 +48,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
# No separate SampleState-- sample state is just below_estimate_query's
|
||||
# SampleState.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
def __init__(self,
|
||||
initial_estimate,
|
||||
target_quantile,
|
||||
learning_rate,
|
||||
|
@ -65,11 +60,11 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
Args:
|
||||
initial_estimate: The initial estimate of the quantile.
|
||||
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
||||
should be found for which approximately 80% of updates are
|
||||
less than the estimate each round.
|
||||
learning_rate: The learning rate. A rate of r means that the estimate
|
||||
will change by a maximum of r at each step (for arithmetic updating) or
|
||||
by a maximum factor of exp(r) (for geometric updating).
|
||||
should be found for which approximately 80% of updates are less than the
|
||||
estimate each round.
|
||||
learning_rate: The learning rate. A rate of r means that the estimate will
|
||||
change by a maximum of r at each step (for arithmetic updating) or by a
|
||||
maximum factor of exp(r) (for geometric updating).
|
||||
below_estimate_stddev: The stddev of the noise added to the count of
|
||||
records currently below the estimate. Since the sensitivity of the count
|
||||
query is 0.5, as a rule of thumb it should be about 0.5 for reasonable
|
||||
|
@ -90,8 +85,8 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
self._geometric_update = geometric_update
|
||||
|
||||
def _construct_below_estimate_query(
|
||||
self, below_estimate_stddev, expected_num_records):
|
||||
def _construct_below_estimate_query(self, below_estimate_stddev,
|
||||
expected_num_records):
|
||||
# A DPQuery used to estimate the fraction of records that are less than the
|
||||
# current quantile estimate. It accumulates an indicator 0/1 of whether each
|
||||
# record is below the estimate, and normalizes by the expected number of
|
||||
|
@ -101,9 +96,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
# affect the count is 0.5. Note that although the l2_norm_clip of the
|
||||
# below_estimate query is 0.5, no clipping will ever actually occur
|
||||
# because the value of each record is always +/-0.5.
|
||||
return gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=0.5,
|
||||
sum_stddev=below_estimate_stddev,
|
||||
return normalized_query.NormalizedQuery(
|
||||
gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=0.5, stddev=below_estimate_stddev),
|
||||
denominator=expected_num_records)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
|
@ -140,8 +135,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
def get_noised_result(self, sample_state, global_state):
|
||||
below_estimate_result, new_below_estimate_state = (
|
||||
self._below_estimate_query.get_noised_result(
|
||||
sample_state,
|
||||
global_state.below_estimate_state))
|
||||
sample_state, global_state.below_estimate_state))
|
||||
|
||||
# Unshift below_estimate percentile by 0.5. (See comment in initializer.)
|
||||
below_estimate = below_estimate_result + 0.5
|
||||
|
@ -177,8 +171,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
|||
below estimate with an exact denominator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
def __init__(self,
|
||||
initial_estimate,
|
||||
target_quantile,
|
||||
learning_rate,
|
||||
|
@ -188,22 +181,25 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
|||
Args:
|
||||
initial_estimate: The initial estimate of the quantile.
|
||||
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
||||
should be found for which approximately 80% of updates are
|
||||
less than the estimate each round.
|
||||
learning_rate: The learning rate. A rate of r means that the estimate
|
||||
will change by a maximum of r at each step (for arithmetic updating) or
|
||||
by a maximum factor of exp(r) (for geometric updating).
|
||||
should be found for which approximately 80% of updates are less than the
|
||||
estimate each round.
|
||||
learning_rate: The learning rate. A rate of r means that the estimate will
|
||||
change by a maximum of r at each step (for arithmetic updating) or by a
|
||||
maximum factor of exp(r) (for geometric updating).
|
||||
geometric_update: If True, use geometric updating of estimate. Geometric
|
||||
updating is preferred for non-negative records like vector norms that
|
||||
could potentially be very large or very close to zero.
|
||||
"""
|
||||
super(NoPrivacyQuantileEstimatorQuery, self).__init__(
|
||||
initial_estimate, target_quantile, learning_rate,
|
||||
below_estimate_stddev=None, expected_num_records=None,
|
||||
initial_estimate,
|
||||
target_quantile,
|
||||
learning_rate,
|
||||
below_estimate_stddev=None,
|
||||
expected_num_records=None,
|
||||
geometric_update=geometric_update)
|
||||
|
||||
def _construct_below_estimate_query(
|
||||
self, below_estimate_stddev, expected_num_records):
|
||||
def _construct_below_estimate_query(self, below_estimate_stddev,
|
||||
expected_num_records):
|
||||
del below_estimate_stddev
|
||||
del expected_num_records
|
||||
return no_privacy_query.NoPrivacyAverageQuery()
|
||||
|
|
Loading…
Reference in a new issue