Update TensorFlow Privacy to use Python 3 metaclass
.
PiperOrigin-RevId: 424773127
This commit is contained in:
parent
b0803999ad
commit
9050f18b59
2 changed files with 35 additions and 4 deletions
|
@ -47,7 +47,7 @@ import collections
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
|
||||||
class DPQuery(object):
|
class DPQuery(metaclass=abc.ABCMeta):
|
||||||
"""Interface for differentially private query mechanisms.
|
"""Interface for differentially private query mechanisms.
|
||||||
|
|
||||||
Differential privacy is achieved by processing records to bound sensitivity,
|
Differential privacy is achieved by processing records to bound sensitivity,
|
||||||
|
@ -93,8 +93,6 @@ class DPQuery(object):
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__metaclass__ = abc.ABCMeta
|
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""Returns the initial global state for the DPQuery.
|
"""Returns the initial global state for the DPQuery.
|
||||||
|
|
||||||
|
|
|
@ -120,7 +120,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
def test_raises_with_non_sum(self):
|
def test_raises_with_non_sum(self):
|
||||||
|
|
||||||
class NonSumDPQuery(dp_query.DPQuery):
|
class NonSumDPQuery(dp_query.DPQuery):
|
||||||
pass
|
|
||||||
|
def initial_sample_state(self, template=None):
|
||||||
|
del template # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def accumulate_preprocessed_record(self, sample_state,
|
||||||
|
preprocessed_record):
|
||||||
|
del sample_state, preprocessed_record # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
|
del sample_state_1, sample_state_2 # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_noised_result(self, sample_state, global_state):
|
||||||
|
del sample_state, global_state # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
non_sum_query = NonSumDPQuery()
|
non_sum_query = NonSumDPQuery()
|
||||||
|
|
||||||
|
@ -138,6 +154,23 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
def __init__(self, metric_val):
|
def __init__(self, metric_val):
|
||||||
self._metric_val = metric_val
|
self._metric_val = metric_val
|
||||||
|
|
||||||
|
def initial_sample_state(self, template=None):
|
||||||
|
del template # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def accumulate_preprocessed_record(self, sample_state,
|
||||||
|
preprocessed_record):
|
||||||
|
del sample_state, preprocessed_record # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
|
del sample_state_1, sample_state_2 # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_noised_result(self, sample_state, global_state):
|
||||||
|
del sample_state, global_state # Unused.
|
||||||
|
return None
|
||||||
|
|
||||||
def derive_metrics(self, global_state):
|
def derive_metrics(self, global_state):
|
||||||
return collections.OrderedDict(metric=self._metric_val)
|
return collections.OrderedDict(metric=self._metric_val)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue