Adds NeighboringRelation to Accountant and clarifies FixedBatchSample events to be with or without replacement.

PiperOrigin-RevId: 393459878
This commit is contained in:
Galen Andrew 2021-08-27 17:33:11 -07:00 committed by A. Unique TensorFlower
parent 48e4836a36
commit 07c248d868
2 changed files with 28 additions and 3 deletions

View file

@ -85,8 +85,16 @@ class PoissonSampledDpEvent(DpEvent):
@attr.s(frozen=True, slots=True, auto_attribs=True)
class EqualBatchSampledDpEvent(DpEvent):
"""An application of sampling exactly `batch_size` records."""
class FixedBatchSampledWrDpEvent(DpEvent):
"""Sampling exactly `batch_size` records with replacement."""
dataset_size: int
batch_size: int
event: DpEvent
@attr.s(frozen=True, slots=True, auto_attribs=True)
class FixedBatchSampledWorDpEvent(DpEvent):
"""Sampling exactly `batch_size` records without replacement."""
dataset_size: int
batch_size: int
event: DpEvent

View file

@ -14,17 +14,34 @@
"""PrivacyAccountant abstract base class."""
import abc
import enum
from tensorflow_privacy.privacy.dp_event import dp_event
from tensorflow_privacy.privacy.dp_event import dp_event_builder
class NeighboringRelation(enum.Enum):
ADD_OR_REMOVE_ONE = 1
REPLACE_ONE = 2
class PrivacyAccountant(metaclass=abc.ABCMeta):
"""Abstract base class for privacy accountants."""
def __init__(self):
def __init__(self, neighboring_relation: NeighboringRelation):
self._neighboring_relation = neighboring_relation
self._ledger = dp_event_builder.DpEventBuilder()
@property
def neighboring_relation(self) -> NeighboringRelation:
"""The neighboring relation used by the accountant.
The neighboring relation is expected to remain constant after
initialization. Subclasses should not override this property or change the
value of the private attribute.
"""
return self._neighboring_relation
@abc.abstractmethod
def is_supported(self, event: dp_event.DpEvent) -> bool:
"""Checks whether the `DpEvent` can be processed by this accountant.