Add NestedSumQuery for nested queries with sum aggregation.

PiperOrigin-RevId: 320303703
This commit is contained in:
Galen Andrew 2020-07-08 18:04:39 -07:00 committed by A. Unique TensorFlower
parent c948e2fe7c
commit d1e2cc1930
2 changed files with 38 additions and 6 deletions

View file

@ -108,3 +108,21 @@ class NestedQuery(dp_query.DPQuery):
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery):
"""A NestedQuery that consists only of SumAggregationDPQueries."""
def __init__(self, queries):
"""Initializes the NestedSumQuery.
Args:
queries: A nested structure of queries that must all be
SumAggregationDPQueries.
"""
def check(query):
if not isinstance(query, dp_query.SumAggregationDPQuery):
raise TypeError('All subqueries must be SumAggregationDPQueries.')
tree.map_structure(check, queries)
super(NestedSumQuery, self).__init__(queries)

View file

@ -23,6 +23,7 @@ from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import nested_query
from tensorflow_privacy.privacy.dp_query import test_utils
@ -40,7 +41,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
query2 = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query = nested_query.NestedQuery([query1, query2])
query = nested_query.NestedSumQuery([query1, query2])
record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]]
@ -57,7 +58,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
query = nested_query.NestedQuery([query1, query2])
query = nested_query.NestedSumQuery([query1, query2])
record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]]
@ -74,7 +75,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
query = nested_query.NestedQuery([query1, query2])
query = nested_query.NestedSumQuery([query1, query2])
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
@ -93,7 +94,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
query_d = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query = nested_query.NestedQuery(
query = nested_query.NestedSumQuery(
[query_ab, {'c': query_c, 'd': [query_d]}])
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
@ -113,7 +114,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
l2_norm_clip=1.5, stddev=sum_stddev)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
query = nested_query.NestedQuery((query1, query2))
query = nested_query.NestedSumQuery((query1, query2))
record1 = (3.0, [2.0, 1.5])
record2 = (0.0, [-1.0, -3.5])
@ -136,7 +137,20 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_record_incompatible_with_query(
self, queries, record, error_type):
with self.assertRaises(error_type):
test_utils.run_query(nested_query.NestedQuery(queries), [record])
test_utils.run_query(nested_query.NestedSumQuery(queries), [record])
def test_raises_with_non_sum(self):
class NonSumDPQuery(dp_query.DPQuery):
pass
non_sum_query = NonSumDPQuery()
# This should work.
nested_query.NestedQuery(non_sum_query)
# This should not.
with self.assertRaises(TypeError):
nested_query.NestedSumQuery(non_sum_query)
if __name__ == '__main__':