Update TensorFlow Privacy to use Python 3 metaclass.

PiperOrigin-RevId: 424773127
This commit is contained in:
Michael Reneer 2022-01-27 20:31:35 -08:00 committed by A. Unique TensorFlower
parent b0803999ad
commit 9050f18b59
2 changed files with 35 additions and 4 deletions

View file

@ -47,7 +47,7 @@ import collections
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
class DPQuery(object): class DPQuery(metaclass=abc.ABCMeta):
"""Interface for differentially private query mechanisms. """Interface for differentially private query mechanisms.
Differential privacy is achieved by processing records to bound sensitivity, Differential privacy is achieved by processing records to bound sensitivity,
@ -93,8 +93,6 @@ class DPQuery(object):
``` ```
""" """
__metaclass__ = abc.ABCMeta
def initial_global_state(self): def initial_global_state(self):
"""Returns the initial global state for the DPQuery. """Returns the initial global state for the DPQuery.

View file

@ -120,7 +120,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_raises_with_non_sum(self): def test_raises_with_non_sum(self):
class NonSumDPQuery(dp_query.DPQuery): class NonSumDPQuery(dp_query.DPQuery):
pass
def initial_sample_state(self, template=None):
del template # Unused.
return None
def accumulate_preprocessed_record(self, sample_state,
preprocessed_record):
del sample_state, preprocessed_record # Unused.
return None
def merge_sample_states(self, sample_state_1, sample_state_2):
del sample_state_1, sample_state_2 # Unused.
return None
def get_noised_result(self, sample_state, global_state):
del sample_state, global_state # Unused.
return None
non_sum_query = NonSumDPQuery() non_sum_query = NonSumDPQuery()
@ -138,6 +154,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def __init__(self, metric_val): def __init__(self, metric_val):
self._metric_val = metric_val self._metric_val = metric_val
def initial_sample_state(self, template=None):
del template # Unused.
return None
def accumulate_preprocessed_record(self, sample_state,
preprocessed_record):
del sample_state, preprocessed_record # Unused.
return None
def merge_sample_states(self, sample_state_1, sample_state_2):
del sample_state_1, sample_state_2 # Unused.
return None
def get_noised_result(self, sample_state, global_state):
del sample_state, global_state # Unused.
return None
def derive_metrics(self, global_state): def derive_metrics(self, global_state):
return collections.OrderedDict(metric=self._metric_val) return collections.OrderedDict(metric=self._metric_val)