diff --git a/tensorflow_privacy/privacy/dp_query/dp_query.py b/tensorflow_privacy/privacy/dp_query/dp_query.py index 8e21c63..4dfe962 100644 --- a/tensorflow_privacy/privacy/dp_query/dp_query.py +++ b/tensorflow_privacy/privacy/dp_query/dp_query.py @@ -47,7 +47,7 @@ import collections import tensorflow.compat.v1 as tf -class DPQuery(object): +class DPQuery(metaclass=abc.ABCMeta): """Interface for differentially private query mechanisms. Differential privacy is achieved by processing records to bound sensitivity, @@ -93,8 +93,6 @@ class DPQuery(object): ``` """ - __metaclass__ = abc.ABCMeta - def initial_global_state(self): """Returns the initial global state for the DPQuery. diff --git a/tensorflow_privacy/privacy/dp_query/nested_query_test.py b/tensorflow_privacy/privacy/dp_query/nested_query_test.py index ef461a2..e4b8682 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query_test.py @@ -120,7 +120,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): def test_raises_with_non_sum(self): class NonSumDPQuery(dp_query.DPQuery): - pass + + def initial_sample_state(self, template=None): + del template # Unused. + return None + + def accumulate_preprocessed_record(self, sample_state, + preprocessed_record): + del sample_state, preprocessed_record # Unused. + return None + + def merge_sample_states(self, sample_state_1, sample_state_2): + del sample_state_1, sample_state_2 # Unused. + return None + + def get_noised_result(self, sample_state, global_state): + del sample_state, global_state # Unused. + return None non_sum_query = NonSumDPQuery() @@ -138,6 +154,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): def __init__(self, metric_val): self._metric_val = metric_val + def initial_sample_state(self, template=None): + del template # Unused. + return None + + def accumulate_preprocessed_record(self, sample_state, + preprocessed_record): + del sample_state, preprocessed_record # Unused. + return None + + def merge_sample_states(self, sample_state_1, sample_state_2): + del sample_state_1, sample_state_2 # Unused. + return None + + def get_noised_result(self, sample_state, global_state): + del sample_state, global_state # Unused. + return None + def derive_metrics(self, global_state): return collections.OrderedDict(metric=self._metric_val)