Remove GaussianAverageQuery. Users can simply wrap GaussianSumQuery with a NormalizedQuery.

PiperOrigin-RevId: 374784618
This commit is contained in:
Galen Andrew 2021-05-19 20:19:39 -07:00 committed by A. Unique TensorFlower
parent 1de7e4dde4
commit e5848656ed
7 changed files with 96 additions and 189 deletions

View file

@ -39,7 +39,6 @@ else:
# DPQuery classes # DPQuery classes
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianAverageQuery
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for PrivacyLedger.""" """Tests for PrivacyLedger."""
from __future__ import absolute_import from __future__ import absolute_import
@ -57,10 +56,9 @@ class PrivacyLedgerTest(tf.test.TestCase):
population_size = tf.Variable(0) population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0) selection_probability = tf.Variable(1.0)
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
l2_norm_clip=10.0, stddev=0.0) query = privacy_ledger.QueryWithLedger(query, population_size,
query = privacy_ledger.QueryWithLedger( selection_probability)
query, population_size, selection_probability)
# First sample. # First sample.
tf.assign(population_size, 10) tf.assign(population_size, 10)
@ -93,14 +91,12 @@ class PrivacyLedgerTest(tf.test.TestCase):
population_size = tf.Variable(0) population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0) selection_probability = tf.Variable(1.0)
query1 = gaussian_query.GaussianAverageQuery( query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=2.0)
l2_norm_clip=4.0, sum_stddev=2.0, denominator=5.0) query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=1.0)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=1.0, denominator=5.0)
query = nested_query.NestedQuery([query1, query2]) query = nested_query.NestedQuery([query1, query2])
query = privacy_ledger.QueryWithLedger( query = privacy_ledger.QueryWithLedger(query, population_size,
query, population_size, selection_probability) selection_probability)
record1 = [1.0, [12.0, 9.0]] record1 = [1.0, [12.0, 9.0]]
record2 = [5.0, [1.0, 2.0]] record2 = [5.0, [1.0, 2.0]]

View file

@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Implements DPQuery interface for Gaussian sum queries."""
"""Implements DPQuery interface for Gaussian average queries.
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -25,7 +23,6 @@ import distutils
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import normalized_query
class GaussianSumQuery(dp_query.SumAggregationDPQuery): class GaussianSumQuery(dp_query.SumAggregationDPQuery):
@ -35,8 +32,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
""" """
# pylint: disable=invalid-name # pylint: disable=invalid-name
_GlobalState = collections.namedtuple( _GlobalState = collections.namedtuple('_GlobalState',
'_GlobalState', ['l2_norm_clip', 'stddev']) ['l2_norm_clip', 'stddev'])
def __init__(self, l2_norm_clip, stddev): def __init__(self, l2_norm_clip, stddev):
"""Initializes the GaussianSumQuery. """Initializes the GaussianSumQuery.
@ -55,8 +52,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def make_global_state(self, l2_norm_clip, stddev): def make_global_state(self, l2_norm_clip, stddev):
"""Creates a global state from the given parameters.""" """Creates a global state from the given parameters."""
return self._GlobalState(tf.cast(l2_norm_clip, tf.float32), return self._GlobalState(
tf.cast(stddev, tf.float32)) tf.cast(l2_norm_clip, tf.float32), tf.cast(stddev, tf.float32))
def initial_global_state(self): def initial_global_state(self):
return self.make_global_state(self._l2_norm_clip, self._stddev) return self.make_global_state(self._l2_norm_clip, self._stddev)
@ -94,48 +91,17 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
return v + tf.random.normal( return v + tf.random.normal(
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype) tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
else: else:
random_normal = tf.random_normal_initializer( random_normal = tf.random_normal_initializer(stddev=global_state.stddev)
stddev=global_state.stddev)
def add_noise(v): def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
if self._ledger: if self._ledger:
dependencies = [ dependencies = [
self._ledger.record_sum_query( self._ledger.record_sum_query(global_state.l2_norm_clip,
global_state.l2_norm_clip, global_state.stddev) global_state.stddev)
] ]
else: else:
dependencies = [] dependencies = []
with tf.control_dependencies(dependencies): with tf.control_dependencies(dependencies):
return tf.nest.map_structure(add_noise, sample_state), global_state return tf.nest.map_structure(add_noise, sample_state), global_state
class GaussianAverageQuery(normalized_query.NormalizedQuery):
"""Implements DPQuery interface for Gaussian average queries.
Accumulates clipped vectors, adds Gaussian noise, and normalizes.
Note that we use "fixed-denominator" estimation: the denominator should be
specified as the expected number of records per sample. Accumulating the
denominator separately would also be possible but would be produce a higher
variance estimator.
"""
def __init__(self,
l2_norm_clip,
sum_stddev,
denominator):
"""Initializes the GaussianAverageQuery.
Args:
l2_norm_clip: The clipping norm to apply to the global norm of each
record.
sum_stddev: The stddev of the noise added to the sum (before
normalization).
denominator: The normalization constant (applied after noise is added to
the sum).
"""
super(GaussianAverageQuery, self).__init__(
numerator_query=GaussianSumQuery(l2_norm_clip, sum_stddev),
denominator=denominator)

View file

@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for GaussianSumQuery."""
"""Tests for GaussianAverageQuery."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -34,8 +33,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = tf.constant([2.0, 0.0]) record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0]) record2 = tf.constant([-1.0, 1.0])
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
l2_norm_clip=10.0, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [1.0, 1.0] expected = [1.0, 1.0]
@ -46,8 +44,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
record2 = tf.constant([4.0, -3.0]) # Not clipped. record2 = tf.constant([4.0, -3.0]) # Not clipped.
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
l2_norm_clip=5.0, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result) result = sess.run(query_result)
expected = [1.0, 1.0] expected = [1.0, 1.0]
@ -80,8 +77,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
record1, record2 = 2.71828, 3.14159 record1, record2 = 2.71828, 3.14159
stddev = 1.0 stddev = 1.0
query = gaussian_query.GaussianSumQuery( query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=stddev)
l2_norm_clip=5.0, stddev=stddev)
query_result, _ = test_utils.run_query(query, [record1, record2]) query_result, _ = test_utils.run_query(query, [record1, record2])
noised_sums = [] noised_sums = []
@ -108,8 +104,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
sample_state_2 = get_sample_state(records2) sample_state_2 = get_sample_state(records2)
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states( merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
sample_state_1, sample_state_1, sample_state_2)
sample_state_2)
with self.cached_session() as sess: with self.cached_session() as sess:
result = sess.run(merged) result = sess.run(merged)
@ -117,36 +112,6 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
expected = [3.0, 10.0] expected = [3.0, 10.0]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
def test_gaussian_average_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
record2 = tf.constant([-1.0, 2.0]) # Not clipped.
query = gaussian_query.GaussianAverageQuery(
l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected_average = [1.0, 1.0]
self.assertAllClose(result, expected_average)
def test_gaussian_average_with_noise(self):
with self.cached_session() as sess:
record1, record2 = 2.71828, 3.14159
sum_stddev = 1.0
denominator = 2.0
query = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
query_result, _ = test_utils.run_query(query, [record1, record2])
noised_averages = []
for _ in range(1000):
noised_averages.append(sess.run(query_result))
result_stddev = np.std(noised_averages)
avg_stddev = sum_stddev / denominator
self.assertNear(result_stddev, avg_stddev, 0.1)
@parameterized.named_parameters( @parameterized.named_parameters(
('type_mismatch', [1.0], (1.0,), TypeError), ('type_mismatch', [1.0], (1.0,), TypeError),
('too_few_on_left', [1.0], [1.0, 1.0], ValueError), ('too_few_on_left', [1.0], [1.0, 1.0], ValueError),

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for NestedQuery.""" """Tests for NestedQuery."""
from __future__ import absolute_import from __future__ import absolute_import
@ -27,9 +26,9 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import nested_query from tensorflow_privacy.privacy.dp_query import nested_query
from tensorflow_privacy.privacy.dp_query import normalized_query
from tensorflow_privacy.privacy.dp_query import test_utils from tensorflow_privacy.privacy.dp_query import test_utils
_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0) _basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
@ -37,10 +36,8 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_nested_gaussian_sum_no_clip_no_noise(self): def test_nested_gaussian_sum_no_clip_no_noise(self):
with self.cached_session() as sess: with self.cached_session() as sess:
query1 = gaussian_query.GaussianSumQuery( query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
l2_norm_clip=10.0, stddev=0.0) query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
query2 = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query = nested_query.NestedSumQuery([query1, query2]) query = nested_query.NestedSumQuery([query1, query2])
@ -52,29 +49,14 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
expected = [5.0, [5.0, 5.0]] expected = [5.0, [5.0, 5.0]]
self.assertAllClose(result, expected) self.assertAllClose(result, expected)
def test_nested_gaussian_average_no_clip_no_noise(self):
with self.cached_session() as sess:
query1 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
query = nested_query.NestedSumQuery([query1, query2])
record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]]
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1.0, [1.0, 1.0]]
self.assertAllClose(result, expected)
def test_nested_gaussian_average_with_clip_no_noise(self): def test_nested_gaussian_average_with_clip_no_noise(self):
with self.cached_session() as sess: with self.cached_session() as sess:
query1 = gaussian_query.GaussianAverageQuery( query1 = normalized_query.NormalizedQuery(
l2_norm_clip=4.0, sum_stddev=0.0, denominator=5.0) gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=0.0),
query2 = gaussian_query.GaussianAverageQuery( denominator=5.0)
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0) query2 = normalized_query.NormalizedQuery(
gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0),
denominator=5.0)
query = nested_query.NestedSumQuery([query1, query2]) query = nested_query.NestedSumQuery([query1, query2])
@ -88,15 +70,17 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_complex_nested_query(self): def test_complex_nested_query(self):
with self.cached_session() as sess: with self.cached_session() as sess:
query_ab = gaussian_query.GaussianSumQuery( query_ab = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)
l2_norm_clip=1.0, stddev=0.0) query_c = normalized_query.NormalizedQuery(
query_c = gaussian_query.GaussianAverageQuery( gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0),
l2_norm_clip=10.0, sum_stddev=0.0, denominator=2.0) denominator=2.0)
query_d = gaussian_query.GaussianSumQuery( query_d = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
l2_norm_clip=10.0, stddev=0.0)
query = nested_query.NestedSumQuery( query = nested_query.NestedSumQuery(
[query_ab, {'c': query_c, 'd': [query_d]}]) [query_ab, {
'c': query_c,
'd': [query_d]
}])
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}] record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}] record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
@ -108,13 +92,13 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_nested_query_with_noise(self): def test_nested_query_with_noise(self):
with self.cached_session() as sess: with self.cached_session() as sess:
sum_stddev = 2.71828 stddev = 2.71828
denominator = 3.14159 denominator = 3.14159
query1 = gaussian_query.GaussianSumQuery( query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=1.5, stddev=stddev)
l2_norm_clip=1.5, stddev=sum_stddev) query2 = normalized_query.NormalizedQuery(
query2 = gaussian_query.GaussianAverageQuery( gaussian_query.GaussianSumQuery(l2_norm_clip=0.5, stddev=stddev),
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator) denominator=denominator)
query = nested_query.NestedSumQuery((query1, query2)) query = nested_query.NestedSumQuery((query1, query2))
record1 = (3.0, [2.0, 1.5]) record1 = (3.0, [2.0, 1.5])
@ -127,20 +111,20 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
noised_averages.append(tf.nest.flatten(sess.run(query_result))) noised_averages.append(tf.nest.flatten(sess.run(query_result)))
result_stddev = np.std(noised_averages, 0) result_stddev = np.std(noised_averages, 0)
avg_stddev = sum_stddev / denominator avg_stddev = stddev / denominator
expected_stddev = [sum_stddev, avg_stddev, avg_stddev] expected_stddev = [stddev, avg_stddev, avg_stddev]
self.assertArrayNear(result_stddev, expected_stddev, 0.1) self.assertArrayNear(result_stddev, expected_stddev, 0.1)
@parameterized.named_parameters( @parameterized.named_parameters(
('type_mismatch', [_basic_query], (1.0,), TypeError), ('type_mismatch', [_basic_query], (1.0,), TypeError),
('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError), ('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError),
('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError)) ('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError))
def test_record_incompatible_with_query( def test_record_incompatible_with_query(self, queries, record, error_type):
self, queries, record, error_type):
with self.assertRaises(error_type): with self.assertRaises(error_type):
test_utils.run_query(nested_query.NestedSumQuery(queries), [record]) test_utils.run_query(nested_query.NestedSumQuery(queries), [record])
def test_raises_with_non_sum(self): def test_raises_with_non_sum(self):
class NonSumDPQuery(dp_query.DPQuery): class NonSumDPQuery(dp_query.DPQuery):
pass pass
@ -154,6 +138,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
nested_query.NestedSumQuery(non_sum_query) nested_query.NestedSumQuery(non_sum_query)
def test_metrics(self): def test_metrics(self):
class QueryWithMetric(dp_query.SumAggregationDPQuery): class QueryWithMetric(dp_query.SumAggregationDPQuery):
def __init__(self, metric_val): def __init__(self, metric_val):
@ -171,8 +156,10 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
metric_val = nested_a.derive_metrics(global_state) metric_val = nested_a.derive_metrics(global_state)
self.assertEqual(metric_val['metric'], 1) self.assertEqual(metric_val['metric'], 1)
nested_b = nested_query.NestedSumQuery( nested_b = nested_query.NestedSumQuery({
{'key1': query1, 'key2': [query2, query3]}) 'key1': query1,
'key2': [query2, query3]
})
global_state = nested_b.initial_global_state() global_state = nested_b.initial_global_state()
metric_val = nested_b.derive_metrics(global_state) metric_val = nested_b.derive_metrics(global_state)
self.assertEqual(metric_val['key1/metric'], 1) self.assertEqual(metric_val['key1/metric'], 1)

View file

@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for NormalizedQuery."""
"""Tests for GaussianAverageQuery."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -32,8 +31,7 @@ class NormalizedQueryTest(tf.test.TestCase):
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
record2 = tf.constant([4.0, -3.0]) # Not clipped. record2 = tf.constant([4.0, -3.0]) # Not clipped.
sum_query = gaussian_query.GaussianSumQuery( sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
l2_norm_clip=5.0, stddev=0.0)
query = normalized_query.NormalizedQuery( query = normalized_query.NormalizedQuery(
numerator_query=sum_query, denominator=2.0) numerator_query=sum_query, denominator=2.0)

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Implements DPQuery interface for quantile estimator. """Implements DPQuery interface for quantile estimator.
From a starting estimate of the target quantile, the estimate is updated From a starting estimate of the target quantile, the estimate is updated
@ -27,23 +26,20 @@ from __future__ import print_function
import collections import collections
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import no_privacy_query from tensorflow_privacy.privacy.dp_query import no_privacy_query
from tensorflow_privacy.privacy.dp_query import normalized_query
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
"""Iterative process to estimate target quantile of a univariate distribution. """Iterative process to estimate target quantile of a univariate distribution."""
"""
# pylint: disable=invalid-name # pylint: disable=invalid-name
_GlobalState = collections.namedtuple( _GlobalState = collections.namedtuple('_GlobalState', [
'_GlobalState', [ 'current_estimate', 'target_quantile', 'learning_rate',
'current_estimate', 'below_estimate_state'
'target_quantile', ])
'learning_rate',
'below_estimate_state'])
# pylint: disable=invalid-name # pylint: disable=invalid-name
_SampleParams = collections.namedtuple( _SampleParams = collections.namedtuple(
@ -52,8 +48,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
# No separate SampleState-- sample state is just below_estimate_query's # No separate SampleState-- sample state is just below_estimate_query's
# SampleState. # SampleState.
def __init__( def __init__(self,
self,
initial_estimate, initial_estimate,
target_quantile, target_quantile,
learning_rate, learning_rate,
@ -65,11 +60,11 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
Args: Args:
initial_estimate: The initial estimate of the quantile. initial_estimate: The initial estimate of the quantile.
target_quantile: The target quantile. I.e., a value of 0.8 means a value target_quantile: The target quantile. I.e., a value of 0.8 means a value
should be found for which approximately 80% of updates are should be found for which approximately 80% of updates are less than the
less than the estimate each round. estimate each round.
learning_rate: The learning rate. A rate of r means that the estimate learning_rate: The learning rate. A rate of r means that the estimate will
will change by a maximum of r at each step (for arithmetic updating) or change by a maximum of r at each step (for arithmetic updating) or by a
by a maximum factor of exp(r) (for geometric updating). maximum factor of exp(r) (for geometric updating).
below_estimate_stddev: The stddev of the noise added to the count of below_estimate_stddev: The stddev of the noise added to the count of
records currently below the estimate. Since the sensitivity of the count records currently below the estimate. Since the sensitivity of the count
query is 0.5, as a rule of thumb it should be about 0.5 for reasonable query is 0.5, as a rule of thumb it should be about 0.5 for reasonable
@ -90,8 +85,8 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
self._geometric_update = geometric_update self._geometric_update = geometric_update
def _construct_below_estimate_query( def _construct_below_estimate_query(self, below_estimate_stddev,
self, below_estimate_stddev, expected_num_records): expected_num_records):
# A DPQuery used to estimate the fraction of records that are less than the # A DPQuery used to estimate the fraction of records that are less than the
# current quantile estimate. It accumulates an indicator 0/1 of whether each # current quantile estimate. It accumulates an indicator 0/1 of whether each
# record is below the estimate, and normalizes by the expected number of # record is below the estimate, and normalizes by the expected number of
@ -101,9 +96,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
# affect the count is 0.5. Note that although the l2_norm_clip of the # affect the count is 0.5. Note that although the l2_norm_clip of the
# below_estimate query is 0.5, no clipping will ever actually occur # below_estimate query is 0.5, no clipping will ever actually occur
# because the value of each record is always +/-0.5. # because the value of each record is always +/-0.5.
return gaussian_query.GaussianAverageQuery( return normalized_query.NormalizedQuery(
l2_norm_clip=0.5, gaussian_query.GaussianSumQuery(
sum_stddev=below_estimate_stddev, l2_norm_clip=0.5, stddev=below_estimate_stddev),
denominator=expected_num_records) denominator=expected_num_records)
def set_ledger(self, ledger): def set_ledger(self, ledger):
@ -140,8 +135,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state): def get_noised_result(self, sample_state, global_state):
below_estimate_result, new_below_estimate_state = ( below_estimate_result, new_below_estimate_state = (
self._below_estimate_query.get_noised_result( self._below_estimate_query.get_noised_result(
sample_state, sample_state, global_state.below_estimate_state))
global_state.below_estimate_state))
# Unshift below_estimate percentile by 0.5. (See comment in initializer.) # Unshift below_estimate percentile by 0.5. (See comment in initializer.)
below_estimate = below_estimate_result + 0.5 below_estimate = below_estimate_result + 0.5
@ -177,8 +171,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
below estimate with an exact denominator. below estimate with an exact denominator.
""" """
def __init__( def __init__(self,
self,
initial_estimate, initial_estimate,
target_quantile, target_quantile,
learning_rate, learning_rate,
@ -188,22 +181,25 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
Args: Args:
initial_estimate: The initial estimate of the quantile. initial_estimate: The initial estimate of the quantile.
target_quantile: The target quantile. I.e., a value of 0.8 means a value target_quantile: The target quantile. I.e., a value of 0.8 means a value
should be found for which approximately 80% of updates are should be found for which approximately 80% of updates are less than the
less than the estimate each round. estimate each round.
learning_rate: The learning rate. A rate of r means that the estimate learning_rate: The learning rate. A rate of r means that the estimate will
will change by a maximum of r at each step (for arithmetic updating) or change by a maximum of r at each step (for arithmetic updating) or by a
by a maximum factor of exp(r) (for geometric updating). maximum factor of exp(r) (for geometric updating).
geometric_update: If True, use geometric updating of estimate. Geometric geometric_update: If True, use geometric updating of estimate. Geometric
updating is preferred for non-negative records like vector norms that updating is preferred for non-negative records like vector norms that
could potentially be very large or very close to zero. could potentially be very large or very close to zero.
""" """
super(NoPrivacyQuantileEstimatorQuery, self).__init__( super(NoPrivacyQuantileEstimatorQuery, self).__init__(
initial_estimate, target_quantile, learning_rate, initial_estimate,
below_estimate_stddev=None, expected_num_records=None, target_quantile,
learning_rate,
below_estimate_stddev=None,
expected_num_records=None,
geometric_update=geometric_update) geometric_update=geometric_update)
def _construct_below_estimate_query( def _construct_below_estimate_query(self, below_estimate_stddev,
self, below_estimate_stddev, expected_num_records): expected_num_records):
del below_estimate_stddev del below_estimate_stddev
del expected_num_records del expected_num_records
return no_privacy_query.NoPrivacyAverageQuery() return no_privacy_query.NoPrivacyAverageQuery()