forked from 626_privacy/tensorflow_privacy
Add NestedSumQuery for nested queries with sum aggregation.
PiperOrigin-RevId: 320303703
This commit is contained in:
parent
c948e2fe7c
commit
d1e2cc1930
2 changed files with 38 additions and 6 deletions
|
@ -108,3 +108,21 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
|
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
|
||||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||||
|
|
||||||
|
|
||||||
|
class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery):
|
||||||
|
"""A NestedQuery that consists only of SumAggregationDPQueries."""
|
||||||
|
|
||||||
|
def __init__(self, queries):
|
||||||
|
"""Initializes the NestedSumQuery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queries: A nested structure of queries that must all be
|
||||||
|
SumAggregationDPQueries.
|
||||||
|
"""
|
||||||
|
def check(query):
|
||||||
|
if not isinstance(query, dp_query.SumAggregationDPQuery):
|
||||||
|
raise TypeError('All subqueries must be SumAggregationDPQueries.')
|
||||||
|
tree.map_structure(check, queries)
|
||||||
|
|
||||||
|
super(NestedSumQuery, self).__init__(queries)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import nested_query
|
from tensorflow_privacy.privacy.dp_query import nested_query
|
||||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
|
@ -40,7 +41,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
query2 = gaussian_query.GaussianSumQuery(
|
query2 = gaussian_query.GaussianSumQuery(
|
||||||
l2_norm_clip=10.0, stddev=0.0)
|
l2_norm_clip=10.0, stddev=0.0)
|
||||||
|
|
||||||
query = nested_query.NestedQuery([query1, query2])
|
query = nested_query.NestedSumQuery([query1, query2])
|
||||||
|
|
||||||
record1 = [1.0, [2.0, 3.0]]
|
record1 = [1.0, [2.0, 3.0]]
|
||||||
record2 = [4.0, [3.0, 2.0]]
|
record2 = [4.0, [3.0, 2.0]]
|
||||||
|
@ -57,7 +58,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
query2 = gaussian_query.GaussianAverageQuery(
|
query2 = gaussian_query.GaussianAverageQuery(
|
||||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
||||||
|
|
||||||
query = nested_query.NestedQuery([query1, query2])
|
query = nested_query.NestedSumQuery([query1, query2])
|
||||||
|
|
||||||
record1 = [1.0, [2.0, 3.0]]
|
record1 = [1.0, [2.0, 3.0]]
|
||||||
record2 = [4.0, [3.0, 2.0]]
|
record2 = [4.0, [3.0, 2.0]]
|
||||||
|
@ -74,7 +75,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
query2 = gaussian_query.GaussianAverageQuery(
|
query2 = gaussian_query.GaussianAverageQuery(
|
||||||
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
|
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
|
||||||
|
|
||||||
query = nested_query.NestedQuery([query1, query2])
|
query = nested_query.NestedSumQuery([query1, query2])
|
||||||
|
|
||||||
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
|
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
|
||||||
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
|
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
|
||||||
|
@ -93,7 +94,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
query_d = gaussian_query.GaussianSumQuery(
|
query_d = gaussian_query.GaussianSumQuery(
|
||||||
l2_norm_clip=10.0, stddev=0.0)
|
l2_norm_clip=10.0, stddev=0.0)
|
||||||
|
|
||||||
query = nested_query.NestedQuery(
|
query = nested_query.NestedSumQuery(
|
||||||
[query_ab, {'c': query_c, 'd': [query_d]}])
|
[query_ab, {'c': query_c, 'd': [query_d]}])
|
||||||
|
|
||||||
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
||||||
|
@ -113,7 +114,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
l2_norm_clip=1.5, stddev=sum_stddev)
|
l2_norm_clip=1.5, stddev=sum_stddev)
|
||||||
query2 = gaussian_query.GaussianAverageQuery(
|
query2 = gaussian_query.GaussianAverageQuery(
|
||||||
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
|
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
|
||||||
query = nested_query.NestedQuery((query1, query2))
|
query = nested_query.NestedSumQuery((query1, query2))
|
||||||
|
|
||||||
record1 = (3.0, [2.0, 1.5])
|
record1 = (3.0, [2.0, 1.5])
|
||||||
record2 = (0.0, [-1.0, -3.5])
|
record2 = (0.0, [-1.0, -3.5])
|
||||||
|
@ -136,7 +137,20 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
def test_record_incompatible_with_query(
|
def test_record_incompatible_with_query(
|
||||||
self, queries, record, error_type):
|
self, queries, record, error_type):
|
||||||
with self.assertRaises(error_type):
|
with self.assertRaises(error_type):
|
||||||
test_utils.run_query(nested_query.NestedQuery(queries), [record])
|
test_utils.run_query(nested_query.NestedSumQuery(queries), [record])
|
||||||
|
|
||||||
|
def test_raises_with_non_sum(self):
|
||||||
|
class NonSumDPQuery(dp_query.DPQuery):
|
||||||
|
pass
|
||||||
|
|
||||||
|
non_sum_query = NonSumDPQuery()
|
||||||
|
|
||||||
|
# This should work.
|
||||||
|
nested_query.NestedQuery(non_sum_query)
|
||||||
|
|
||||||
|
# This should not.
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
nested_query.NestedSumQuery(non_sum_query)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue