Add NestedSumQuery for nested queries with sum aggregation.
PiperOrigin-RevId: 320303703
This commit is contained in:
parent
c948e2fe7c
commit
d1e2cc1930
2 changed files with 38 additions and 6 deletions
|
@ -108,3 +108,21 @@ class NestedQuery(dp_query.DPQuery):
|
|||
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
|
||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||
|
||||
|
||||
class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery):
|
||||
"""A NestedQuery that consists only of SumAggregationDPQueries."""
|
||||
|
||||
def __init__(self, queries):
|
||||
"""Initializes the NestedSumQuery.
|
||||
|
||||
Args:
|
||||
queries: A nested structure of queries that must all be
|
||||
SumAggregationDPQueries.
|
||||
"""
|
||||
def check(query):
|
||||
if not isinstance(query, dp_query.SumAggregationDPQuery):
|
||||
raise TypeError('All subqueries must be SumAggregationDPQueries.')
|
||||
tree.map_structure(check, queries)
|
||||
|
||||
super(NestedSumQuery, self).__init__(queries)
|
||||
|
|
|
@ -23,6 +23,7 @@ from absl.testing import parameterized
|
|||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import nested_query
|
||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||
|
@ -40,7 +41,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
query2 = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
|
||||
query = nested_query.NestedQuery([query1, query2])
|
||||
query = nested_query.NestedSumQuery([query1, query2])
|
||||
|
||||
record1 = [1.0, [2.0, 3.0]]
|
||||
record2 = [4.0, [3.0, 2.0]]
|
||||
|
@ -57,7 +58,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
||||
|
||||
query = nested_query.NestedQuery([query1, query2])
|
||||
query = nested_query.NestedSumQuery([query1, query2])
|
||||
|
||||
record1 = [1.0, [2.0, 3.0]]
|
||||
record2 = [4.0, [3.0, 2.0]]
|
||||
|
@ -74,7 +75,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
|
||||
|
||||
query = nested_query.NestedQuery([query1, query2])
|
||||
query = nested_query.NestedSumQuery([query1, query2])
|
||||
|
||||
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
|
||||
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
|
||||
|
@ -93,7 +94,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
query_d = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
|
||||
query = nested_query.NestedQuery(
|
||||
query = nested_query.NestedSumQuery(
|
||||
[query_ab, {'c': query_c, 'd': [query_d]}])
|
||||
|
||||
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
||||
|
@ -113,7 +114,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
l2_norm_clip=1.5, stddev=sum_stddev)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
|
||||
query = nested_query.NestedQuery((query1, query2))
|
||||
query = nested_query.NestedSumQuery((query1, query2))
|
||||
|
||||
record1 = (3.0, [2.0, 1.5])
|
||||
record2 = (0.0, [-1.0, -3.5])
|
||||
|
@ -136,7 +137,20 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
def test_record_incompatible_with_query(
|
||||
self, queries, record, error_type):
|
||||
with self.assertRaises(error_type):
|
||||
test_utils.run_query(nested_query.NestedQuery(queries), [record])
|
||||
test_utils.run_query(nested_query.NestedSumQuery(queries), [record])
|
||||
|
||||
def test_raises_with_non_sum(self):
|
||||
class NonSumDPQuery(dp_query.DPQuery):
|
||||
pass
|
||||
|
||||
non_sum_query = NonSumDPQuery()
|
||||
|
||||
# This should work.
|
||||
nested_query.NestedQuery(non_sum_query)
|
||||
|
||||
# This should not.
|
||||
with self.assertRaises(TypeError):
|
||||
nested_query.NestedSumQuery(non_sum_query)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in a new issue