diff --git a/tensorflow_privacy/privacy/dp_query/nested_query.py b/tensorflow_privacy/privacy/dp_query/nested_query.py index dfdb6e8..75b2db1 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query.py @@ -108,3 +108,21 @@ class NestedQuery(dp_query.DPQuery): *tree.flatten_up_to(self._queries, estimates_and_new_global_states)) return (tf.nest.pack_sequence_as(self._queries, flat_estimates), tf.nest.pack_sequence_as(self._queries, flat_new_global_states)) + + +class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery): + """A NestedQuery that consists only of SumAggregationDPQueries.""" + + def __init__(self, queries): + """Initializes the NestedSumQuery. + + Args: + queries: A nested structure of queries that must all be + SumAggregationDPQueries. + """ + def check(query): + if not isinstance(query, dp_query.SumAggregationDPQuery): + raise TypeError('All subqueries must be SumAggregationDPQueries.') + tree.map_structure(check, queries) + + super(NestedSumQuery, self).__init__(queries) diff --git a/tensorflow_privacy/privacy/dp_query/nested_query_test.py b/tensorflow_privacy/privacy/dp_query/nested_query_test.py index 06749b1..625487e 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query_test.py @@ -23,6 +23,7 @@ from absl.testing import parameterized import numpy as np import tensorflow.compat.v1 as tf +from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import nested_query from tensorflow_privacy.privacy.dp_query import test_utils @@ -40,7 +41,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): query2 = gaussian_query.GaussianSumQuery( l2_norm_clip=10.0, stddev=0.0) - query = nested_query.NestedQuery([query1, query2]) + query = nested_query.NestedSumQuery([query1, query2]) record1 = [1.0, [2.0, 3.0]] record2 = [4.0, [3.0, 2.0]] @@ -57,7 +58,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): query2 = gaussian_query.GaussianAverageQuery( l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0) - query = nested_query.NestedQuery([query1, query2]) + query = nested_query.NestedSumQuery([query1, query2]) record1 = [1.0, [2.0, 3.0]] record2 = [4.0, [3.0, 2.0]] @@ -74,7 +75,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): query2 = gaussian_query.GaussianAverageQuery( l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0) - query = nested_query.NestedQuery([query1, query2]) + query = nested_query.NestedSumQuery([query1, query2]) record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]] record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]] @@ -93,7 +94,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): query_d = gaussian_query.GaussianSumQuery( l2_norm_clip=10.0, stddev=0.0) - query = nested_query.NestedQuery( + query = nested_query.NestedSumQuery( [query_ab, {'c': query_c, 'd': [query_d]}]) record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}] @@ -113,7 +114,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): l2_norm_clip=1.5, stddev=sum_stddev) query2 = gaussian_query.GaussianAverageQuery( l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator) - query = nested_query.NestedQuery((query1, query2)) + query = nested_query.NestedSumQuery((query1, query2)) record1 = (3.0, [2.0, 1.5]) record2 = (0.0, [-1.0, -3.5]) @@ -136,7 +137,20 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): def test_record_incompatible_with_query( self, queries, record, error_type): with self.assertRaises(error_type): - test_utils.run_query(nested_query.NestedQuery(queries), [record]) + test_utils.run_query(nested_query.NestedSumQuery(queries), [record]) + + def test_raises_with_non_sum(self): + class NonSumDPQuery(dp_query.DPQuery): + pass + + non_sum_query = NonSumDPQuery() + + # This should work. + nested_query.NestedQuery(non_sum_query) + + # This should not. + with self.assertRaises(TypeError): + nested_query.NestedSumQuery(non_sum_query) if __name__ == '__main__':