Update TensorFlow Privacy to use Python 3 metaclass.

PiperOrigin-RevId: 424773127
This commit is contained in:
Michael Reneer 2022-01-27 20:31:35 -08:00 committed by A. Unique TensorFlower
parent b0803999ad
commit 9050f18b59
2 changed files with 35 additions and 4 deletions

View file

@ -47,7 +47,7 @@ import collections
import tensorflow.compat.v1 as tf
class DPQuery(object):
class DPQuery(metaclass=abc.ABCMeta):
"""Interface for differentially private query mechanisms.
Differential privacy is achieved by processing records to bound sensitivity,
@ -93,8 +93,6 @@ class DPQuery(object):
```
"""
__metaclass__ = abc.ABCMeta
def initial_global_state(self):
"""Returns the initial global state for the DPQuery.

View file

@ -120,7 +120,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_raises_with_non_sum(self):
class NonSumDPQuery(dp_query.DPQuery):
pass
def initial_sample_state(self, template=None):
del template # Unused.
return None
def accumulate_preprocessed_record(self, sample_state,
preprocessed_record):
del sample_state, preprocessed_record # Unused.
return None
def merge_sample_states(self, sample_state_1, sample_state_2):
del sample_state_1, sample_state_2 # Unused.
return None
def get_noised_result(self, sample_state, global_state):
del sample_state, global_state # Unused.
return None
non_sum_query = NonSumDPQuery()
@ -138,6 +154,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def __init__(self, metric_val):
self._metric_val = metric_val
def initial_sample_state(self, template=None):
del template # Unused.
return None
def accumulate_preprocessed_record(self, sample_state,
preprocessed_record):
del sample_state, preprocessed_record # Unused.
return None
def merge_sample_states(self, sample_state_1, sample_state_2):
del sample_state_1, sample_state_2 # Unused.
return None
def get_noised_result(self, sample_state, global_state):
del sample_state, global_state # Unused.
return None
def derive_metrics(self, global_state):
return collections.OrderedDict(metric=self._metric_val)