diff --git a/tensorflow_privacy/privacy/analysis/BUILD b/tensorflow_privacy/privacy/analysis/BUILD index 79e3f27..63abc69 100644 --- a/tensorflow_privacy/privacy/analysis/BUILD +++ b/tensorflow_privacy/privacy/analysis/BUILD @@ -51,55 +51,11 @@ py_test( deps = [":compute_noise_from_budget_lib"], ) -py_library( - name = "dp_event", - srcs = ["dp_event.py"], - srcs_version = "PY3", -) - -py_library( - name = "dp_event_builder", - srcs = ["dp_event_builder.py"], - srcs_version = "PY3", - deps = [":dp_event"], -) - -py_test( - name = "dp_event_builder_test", - srcs = ["dp_event_builder_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":dp_event", - ":dp_event_builder", - ], -) - py_library( name = "gdp_accountant", srcs = ["gdp_accountant.py"], ) -py_library( - name = "privacy_accountant", - srcs = ["privacy_accountant.py"], - srcs_version = "PY3", - deps = [ - ":dp_event", - ":dp_event_builder", - ], -) - -py_library( - name = "privacy_accountant_test", - srcs = ["privacy_accountant_test.py"], - srcs_version = "PY3", - deps = [ - ":dp_event", - ":privacy_accountant", - ], -) - py_library( name = "rdp_accountant", srcs = ["rdp_accountant.py"], @@ -116,30 +72,6 @@ py_test( deps = [":rdp_accountant"], ) -py_library( - name = "rdp_privacy_accountant", - srcs = ["rdp_privacy_accountant.py"], - srcs_version = "PY3", - deps = [ - ":dp_event", - ":privacy_accountant", - ], -) - -py_test( - name = "rdp_privacy_accountant_test", - size = "small", - srcs = ["rdp_privacy_accountant_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":dp_event", - ":privacy_accountant", - ":privacy_accountant_test", - ":rdp_privacy_accountant", - ], -) - py_library( name = "tensor_buffer", srcs = ["tensor_buffer.py"], diff --git a/tensorflow_privacy/privacy/analysis/dp_event.py b/tensorflow_privacy/privacy/analysis/dp_event.py deleted file mode 100644 index 5d37b43..0000000 --- a/tensorflow_privacy/privacy/analysis/dp_event.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2021, The TensorFlow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Standard DpEvent classes. - -A `DpEvent` represents the (hyper)parameters of a differentially -private query, amplification mechanism, or composition, that are necessary -and sufficient for privacy accounting. Various independent implementations of DP -algorithms that are functionally equivalent from an accounting perspective may -correspond to the same `DpEvent`. Similarly, various independent implementations -of accounting algorithms may consume the same `DpEvent`. - -All `DpEvents` processed together are assumed to take place on a single dataset -of records. `DpEvents` fall into roughly three categories: - - `DpEvents` that release an output, and incur a privacy cost, - e.g., `GaussianDpEvent`. - - `DpEvents` that select a subset (or subsets) of the dataset, and run nested - `DpEvents` on those subsets, e.g., `PoissonSampledDpEvent`. - - `DpEvents` that represent (possibly sequentially) applying (multiple) - mechanisms to the dataset (or currently active subset). Currently, this is - only `ComposedDpEvent` and `SelfComposedDpEvent`. - -Each `DpEvent` should completely document the mathematical behavior and -assumptions of the mechanism it represents so that the writer of an accountant -class can implement the accounting correctly without knowing any other -implementation details of the algorithm that produced it. - -New mechanism types should be given a corresponding `DpEvent` class, although -not all accountants will be required to support them. In general, -`PrivacyAccountant` implementations are not required to be aware of all -`DpEvent` classes, but they should support the following basic events and handle -them appropriately: `NoOpDpEvent`, `NonPrivateDpEvent`, `ComposedDpEvent`, and -`SelfComposedDpEvent`. They should return `supports(event)` is False for -`UnsupportedDpEvent` or any other event type they have not been designed to -handle. - -To ensure that a `PrivacyAccountant` does not accidentally start to return -incorrect results, the following should be enforced: - * `DpEvent` classes and their parameters should never be removed, barring some - extended, onerous deprecation process. - * New parameters cannot be added to existing mechanisms unless they are - optional. That is, old composed `DpEvent` objects that do not include them - must remain valid. - * The meaning of existing mechanisms or parameters must not change. That is, - existing mechanisms should not have their implementations change in ways that - alter their privacy properties; new `DpEvent` classes should be added - instead. - * `PrivacyAccountant` implementations are expected to return `supports(event)` - is `False` when processing unknown mechanisms. -""" - -from typing import List, Union - -import attr - - -class DpEvent(object): - """Represents application of a private mechanism. - - A `DpEvent` describes a differentially private mechanism sufficiently for - computing the associated privacy losses, both in isolation and in combination - with other `DpEvent`s. - """ - - -@attr.s(frozen=True) -class NoOpDpEvent(DpEvent): - """Represents appplication of an operation with no privacy impact. - - A `NoOpDpEvent` is generally never required, but it can be useful as a - placeholder where a `DpEvent` is expected, such as in tests or some live - accounting pipelines. - """ - - -@attr.s(frozen=True) -class NonPrivateDpEvent(DpEvent): - """Represents application of a non-private operation. - - This `DpEvent` should be used when an operation is performed that does not - satisfy (epsilon, delta)-DP. All `PrivacyAccountant`s should return infinite - epsilon/delta when encountering a `NonPrivateDpEvent`. - """ - - -@attr.s(frozen=True) -class UnsupportedDpEvent(DpEvent): - """Represents application of an as-yet unsupported operation. - - This `DpEvent` should be used when an operation is performed that does not yet - have any associated DP description, or if the description is temporarily - inaccessible, for example, during development. All `PrivacyAccountant`s should - return `supports(event) == False` for `UnsupportedDpEvent`. - """ - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class GaussianDpEvent(DpEvent): - """Represents an application of the Gaussian mechanism. - - For values v_i and noise z ~ N(0, s^2I), this mechanism returns sum_i v_i + z. - If the norms of the values are bounded ||v_i|| <= C, the noise_multiplier is - defined as s / C. - """ - noise_multiplier: float - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class LaplaceDpEvent(DpEvent): - """Represents an application of the Laplace mechanism. - - For values v_i and noise z sampled coordinate-wise from the Laplace - distribution L(0, s), this mechanism returns sum_i v_i + z. - The probability density function of the Laplace distribution L(0, s) with - parameter s is given as exp(-|x|/s) * (0.5/s) at x for any real value x. - If the L_1 norm of the values are bounded ||v_i||_1 <= C, the noise_multiplier - is defined as s / C. - """ - noise_multiplier: float - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class SelfComposedDpEvent(DpEvent): - """Represents repeated application of a mechanism. - - The repeated applications may be adaptive, where the query producing each - event depends on the results of prior queries. - - This is equivalent to `ComposedDpEvent` that contains a list of length `count` - of identical copies of `event`. - """ - event: DpEvent - count: int - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class ComposedDpEvent(DpEvent): - """Represents application of a series of composed mechanisms. - - The composition may be adaptive, where the query producing each event depends - on the results of prior queries. - """ - events: List[DpEvent] - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class PoissonSampledDpEvent(DpEvent): - """Represents an application of Poisson subsampling. - - Each record in the dataset is included in the sample independently with - probability `sampling_probability`. Then the `DpEvent` `event` is applied - to the sample of records. - """ - sampling_probability: float - event: DpEvent - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class SampledWithReplacementDpEvent(DpEvent): - """Represents sampling a fixed sized batch of records with replacement. - - A sample of `sample_size` (possibly repeated) records is drawn uniformly at - random from the set of possible samples of a source dataset of size - `source_dataset_size`. Then the `DpEvent` `event` is applied to the sample of - records. - """ - source_dataset_size: int - sample_size: int - event: DpEvent - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class SampledWithoutReplacementDpEvent(DpEvent): - """Represents sampling a fixed sized batch of records without replacement. - - A sample of `sample_size` unique records is drawn uniformly at random from the - set of possible samples of a source dataset of size `source_dataset_size`. - Then the `DpEvent` `event` is applied to the sample of records. - """ - source_dataset_size: int - sample_size: int - event: DpEvent - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class SingleEpochTreeAggregationDpEvent(DpEvent): - """Represents aggregation for a single epoch using one or more trees. - - Multiple tree-aggregation steps can occur, but it is required that each - record occurs at most once *across all trees*. See appendix D of - "Practical and Private (Deep) Learning without Sampling or Shuffling" - https://arxiv.org/abs/2103.00039. - - To represent the common case where the same record can occur in multiple - trees (but still at most once per tree), wrap this with `SelfComposedDpEvent` - or `ComposedDpEvent` and use a scalar for `step_counts`. - - Attributes: - noise_multiplier: The ratio of the noise per node to the sensitivity. - step_counts: The number of steps in each tree. May be a scalar for a single - tree. - """ - noise_multiplier: float - step_counts: Union[int, List[int]] diff --git a/tensorflow_privacy/privacy/analysis/dp_event_builder.py b/tensorflow_privacy/privacy/analysis/dp_event_builder.py deleted file mode 100644 index 53d4cc2..0000000 --- a/tensorflow_privacy/privacy/analysis/dp_event_builder.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2021, The TensorFlow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Builder class for ComposedDpEvent.""" - -from tensorflow_privacy.privacy.analysis import dp_event - - -class DpEventBuilder(object): - """Constructs a `DpEvent` representing the composition of a series of events. - - Two common use cases of the `DpEventBuilder` are 1) for producing and tracking - a ledger of `DpEvent`s during sequential accounting using a - `PrivacyAccountant`, and 2) for building up a description of a composite - mechanism for subsequent batch accounting. - """ - - def __init__(self): - # A list of (event, count) pairs. - self._event_counts = [] - self._composed_event = None - - def compose(self, event: dp_event.DpEvent, count: int = 1): - """Composes new event into event represented by builder. - - Args: - event: The new event to compose. - count: The number of times to compose the event. - """ - if not isinstance(event, dp_event.DpEvent): - raise TypeError('`event` must be a subclass of `DpEvent`. ' - f'Found {type(event)}.') - if not isinstance(count, int): - raise TypeError(f'`count` must be an integer. Found {type(count)}.') - if count < 1: - raise ValueError(f'`count` must be positive. Found {count}.') - - if isinstance(event, dp_event.NoOpDpEvent): - return - elif isinstance(event, dp_event.SelfComposedDpEvent): - self.compose(event.event, count * event.count) - else: - if self._event_counts and self._event_counts[-1][0] == event: - new_event_count = (event, self._event_counts[-1][1] + count) - self._event_counts[-1] = new_event_count - else: - self._event_counts.append((event, count)) - self._composed_event = None - - def build(self) -> dp_event.DpEvent: - """Builds and returns the composed DpEvent represented by the builder.""" - if not self._composed_event: - events = [] - for event, count in self._event_counts: - if count == 1: - events.append(event) - else: - events.append(dp_event.SelfComposedDpEvent(event, count)) - if not events: - self._composed_event = dp_event.NoOpDpEvent() - elif len(events) == 1: - self._composed_event = events[0] - else: - self._composed_event = dp_event.ComposedDpEvent(events) - - return self._composed_event diff --git a/tensorflow_privacy/privacy/analysis/dp_event_builder_test.py b/tensorflow_privacy/privacy/analysis/dp_event_builder_test.py deleted file mode 100644 index 6ca352c..0000000 --- a/tensorflow_privacy/privacy/analysis/dp_event_builder_test.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2021, The TensorFlow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from tensorflow_privacy.privacy.analysis import dp_event -from tensorflow_privacy.privacy.analysis import dp_event_builder - -_gaussian_event = dp_event.GaussianDpEvent(1.0) -_laplace_event = dp_event.LaplaceDpEvent(1.0) -_poisson_event = dp_event.PoissonSampledDpEvent(_gaussian_event, 0.1) -_self_composed_event = dp_event.SelfComposedDpEvent(_gaussian_event, 3) - - -class DpEventBuilderTest(absltest.TestCase): - - def test_no_op(self): - builder = dp_event_builder.DpEventBuilder() - self.assertEqual(dp_event.NoOpDpEvent(), builder.build()) - - def test_single_gaussian(self): - builder = dp_event_builder.DpEventBuilder() - builder.compose(_gaussian_event) - self.assertEqual(_gaussian_event, builder.build()) - - def test_single_laplace(self): - builder = dp_event_builder.DpEventBuilder() - builder.compose(_laplace_event) - self.assertEqual(_laplace_event, builder.build()) - - def test_compose_no_op(self): - builder = dp_event_builder.DpEventBuilder() - builder.compose(dp_event.NoOpDpEvent()) - builder.compose(_gaussian_event) - builder.compose(dp_event.NoOpDpEvent()) - self.assertEqual(_gaussian_event, builder.build()) - - def test_compose_self(self): - builder = dp_event_builder.DpEventBuilder() - builder.compose(_gaussian_event) - builder.compose(_gaussian_event, 2) - self.assertEqual(_self_composed_event, builder.build()) - - def test_compose_heterogenous(self): - builder = dp_event_builder.DpEventBuilder() - builder.compose(_poisson_event) - builder.compose(_gaussian_event) - builder.compose(_gaussian_event, 2) - builder.compose(_poisson_event) - expected_event = dp_event.ComposedDpEvent( - [_poisson_event, _self_composed_event, _poisson_event]) - self.assertEqual(expected_event, builder.build()) - - def test_compose_composed(self): - builder = dp_event_builder.DpEventBuilder() - composed_event = dp_event.ComposedDpEvent( - [_gaussian_event, _poisson_event, _self_composed_event]) - builder.compose(_gaussian_event) - builder.compose(composed_event) - builder.compose(composed_event, 2) - builder.compose(_poisson_event) - builder.compose(_poisson_event) - expected_event = dp_event.ComposedDpEvent([ - _gaussian_event, - dp_event.SelfComposedDpEvent(composed_event, 3), - dp_event.SelfComposedDpEvent(_poisson_event, 2) - ]) - self.assertEqual(expected_event, builder.build()) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_privacy/privacy/analysis/privacy_accountant.py b/tensorflow_privacy/privacy/analysis/privacy_accountant.py deleted file mode 100644 index 89ca70b..0000000 --- a/tensorflow_privacy/privacy/analysis/privacy_accountant.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2021, The TensorFlow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PrivacyAccountant abstract base class.""" - -import abc -import enum - -from tensorflow_privacy.privacy.analysis import dp_event -from tensorflow_privacy.privacy.analysis import dp_event_builder - - -class NeighboringRelation(enum.Enum): - ADD_OR_REMOVE_ONE = 1 - REPLACE_ONE = 2 - - # A record is replaced with a special record, such as the "zero record". See - # https://arxiv.org/pdf/2103.00039.pdf, Definition 1.1. - REPLACE_SPECIAL = 3 - - -class UnsupportedEventError(Exception): - """Exception to raise if _compose is called on unsupported event type.""" - - -class PrivacyAccountant(metaclass=abc.ABCMeta): - """Abstract base class for privacy accountants.""" - - def __init__(self, neighboring_relation: NeighboringRelation): - self._neighboring_relation = neighboring_relation - self._ledger = dp_event_builder.DpEventBuilder() - - @property - def neighboring_relation(self) -> NeighboringRelation: - """The neighboring relation used by the accountant. - - The neighboring relation is expected to remain constant after - initialization. Subclasses should not override this property or change the - value of the private attribute. - """ - return self._neighboring_relation - - @abc.abstractmethod - def supports(self, event: dp_event.DpEvent) -> bool: - """Checks whether the `DpEvent` can be processed by this accountant. - - In general this will require recursively checking the structure of the - `DpEvent`. In particular `ComposedDpEvent` and `SelfComposedDpEvent` should - be recursively examined. - - Args: - event: The `DpEvent` to check. - - Returns: - True iff this accountant supports processing `event`. - """ - - @abc.abstractmethod - def _compose(self, event: dp_event.DpEvent, count: int = 1): - """Updates internal state to account for application of a `DpEvent`. - - Calls to `get_epsilon` or `get_delta` after calling `_compose` will return - values that account for this `DpEvent`. - - Args: - event: A `DpEvent` to process. - count: The number of times to compose the event. - """ - - def compose(self, event: dp_event.DpEvent, count: int = 1): - """Updates internal state to account for application of a `DpEvent`. - - Calls to `get_epsilon` or `get_delta` after calling `compose` will return - values that account for this `DpEvent`. - - Args: - event: A `DpEvent` to process. - count: The number of times to compose the event. - - Raises: - UnsupportedEventError: `event` is not supported by this - `PrivacyAccountant`. - """ - if not isinstance(event, dp_event.DpEvent): - raise TypeError(f'`event` must be `DpEvent`. Found {type(event)}.') - - if not self.supports(event): - raise UnsupportedEventError(f'Unsupported event: {event}.') - - self._ledger.compose(event, count) - self._compose(event, count) - - @property - def ledger(self) -> dp_event.DpEvent: - """Returns the (composed) `DpEvent` processed so far by this accountant.""" - return self._ledger.build() - - @abc.abstractmethod - def get_epsilon(self, target_delta: float) -> float: - """Gets the current epsilon. - - Args: - target_delta: The target delta. - - Returns: - The current epsilon, accounting for all composed `DpEvent`s. - """ - - def get_delta(self, target_epsilon: float) -> float: - """Gets the current delta. - - An implementer of `PrivacyAccountant` may choose not to override this, in - which case `NotImplementedError` will be raised. - - Args: - target_epsilon: The target epsilon. - - Returns: - The current delta, accounting for all composed `DpEvent`s. - """ - raise NotImplementedError() diff --git a/tensorflow_privacy/privacy/analysis/privacy_accountant_test.py b/tensorflow_privacy/privacy/analysis/privacy_accountant_test.py deleted file mode 100644 index 344f3e4..0000000 --- a/tensorflow_privacy/privacy/analysis/privacy_accountant_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Abstract base class for tests of `PrivacyAccountant` classes. - -Checks that a class derived from `PrivacyAccountant` has the correct behavior -for standard `DpEvent` classes. -""" - -from typing import Collection - -from absl.testing import absltest - -from tensorflow_privacy.privacy.analysis import dp_event -from tensorflow_privacy.privacy.analysis import privacy_accountant - - -class PrivacyAccountantTest(absltest.TestCase): - - def _make_test_accountants( - self) -> Collection[privacy_accountant.PrivacyAccountant]: - """Makes a list of accountants to test. - - Subclasses should define this to return a list of accountants to be tested. - - Returns: - A list of accountants to test. - """ - return [] - - def test_make_test_accountants(self): - self.assertNotEmpty(self._make_test_accountants()) - - def test_unsupported(self): - - class UnknownDpEvent(dp_event.DpEvent): - pass - - for accountant in self._make_test_accountants(): - for unsupported in [dp_event.UnsupportedDpEvent(), UnknownDpEvent()]: - self.assertFalse(accountant.supports(unsupported)) - self.assertFalse( - accountant.supports(dp_event.SelfComposedDpEvent(unsupported, 10))) - self.assertFalse( - accountant.supports(dp_event.ComposedDpEvent([unsupported]))) - - def test_no_events(self): - for accountant in self._make_test_accountants(): - self.assertEqual(accountant.get_epsilon(1e-12), 0) - self.assertEqual(accountant.get_epsilon(0), 0) - self.assertEqual(accountant.get_epsilon(1), 0) - try: - self.assertEqual(accountant.get_delta(1e-12), 0) - self.assertEqual(accountant.get_delta(0), 0) - self.assertEqual(accountant.get_delta(float('inf')), 0) - except NotImplementedError: - # Implementing `get_delta` is optional. - pass - - def test_no_op(self): - for accountant in self._make_test_accountants(): - event = dp_event.NoOpDpEvent() - self.assertTrue(accountant.supports(event)) - accountant._compose(event) - self.assertEqual(accountant.get_epsilon(1e-12), 0) - self.assertEqual(accountant.get_epsilon(0), 0) - self.assertEqual(accountant.get_epsilon(1), 0) - try: - self.assertEqual(accountant.get_delta(1e-12), 0) - self.assertEqual(accountant.get_delta(0), 0) - self.assertEqual(accountant.get_delta(float('inf')), 0) - except NotImplementedError: - # Implementing `get_delta` is optional. - pass - - def test_non_private(self): - for accountant in self._make_test_accountants(): - event = dp_event.NonPrivateDpEvent() - self.assertTrue(accountant.supports(event)) - accountant._compose(event) - self.assertEqual(accountant.get_epsilon(0.99), float('inf')) - self.assertEqual(accountant.get_epsilon(0), float('inf')) - self.assertEqual(accountant.get_epsilon(1), float('inf')) - try: - self.assertEqual(accountant.get_delta(100), 1) - self.assertEqual(accountant.get_delta(0), 1) - self.assertEqual(accountant.get_delta(float('inf')), 1) - except NotImplementedError: - # Implementing `get_delta` is optional. - pass diff --git a/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant.py deleted file mode 100644 index 9239f90..0000000 --- a/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant.py +++ /dev/null @@ -1,678 +0,0 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Privacy accountant that uses Renyi differential privacy.""" - -import math -from typing import Callable, Optional, Sequence, Tuple, Union - -import numpy as np -from scipy import special - -from tensorflow_privacy.privacy.analysis import dp_event -from tensorflow_privacy.privacy.analysis import privacy_accountant - -NeighborRel = privacy_accountant.NeighboringRelation - - -def _log_add(logx: float, logy: float) -> float: - """Adds two numbers in the log space.""" - a, b = min(logx, logy), max(logx, logy) - if a == -np.inf: # adding 0 - return b - # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b) - return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1) - - -def _log_sub(logx: float, logy: float) -> float: - """Subtracts two numbers in the log space. Answer must be non-negative.""" - if logx < logy: - raise ValueError('The result of subtraction must be non-negative.') - if logy == -np.inf: # subtracting 0 - return logx - if logx == logy: - return -np.inf # 0 is represented as -np.inf in the log space. - - try: - # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y). - return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1 - except OverflowError: - return logx - - -def _log_sub_sign(logx: float, logy: float) -> Tuple[bool, float]: - """Returns log(exp(logx)-exp(logy)) and its sign.""" - if logx > logy: - s = True - mag = logx + np.log(1 - np.exp(logy - logx)) - elif logx < logy: - s = False - mag = logy + np.log(1 - np.exp(logx - logy)) - else: - s = True - mag = -np.inf - - return s, mag - - -def _log_comb(n: int, k: int) -> float: - """Computes log of binomial coefficient.""" - return (special.gammaln(n + 1) - special.gammaln(k + 1) - - special.gammaln(n - k + 1)) - - -def _compute_log_a_int(q: float, sigma: float, alpha: int) -> float: - """Computes log(A_alpha) for integer alpha, 0 < q < 1.""" - - # Initialize with 0 in the log space. - log_a = -np.inf - - for i in range(alpha + 1): - log_coef_i = ( - _log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q)) - - s = log_coef_i + (i * i - i) / (2 * (sigma**2)) - log_a = _log_add(log_a, s) - - return float(log_a) - - -def _compute_log_a_frac(q: float, sigma: float, alpha: float) -> float: - """Computes log(A_alpha) for fractional alpha, 0 < q < 1.""" - # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are - # initialized to 0 in the log space: - log_a0, log_a1 = -np.inf, -np.inf - i = 0 - - z0 = sigma**2 * math.log(1 / q - 1) + .5 - - while True: # do ... until loop - coef = special.binom(alpha, i) - log_coef = math.log(abs(coef)) - j = alpha - i - - log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q) - log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q) - - log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma)) - log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma)) - - log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0 - log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1 - - if coef > 0: - log_a0 = _log_add(log_a0, log_s0) - log_a1 = _log_add(log_a1, log_s1) - else: - log_a0 = _log_sub(log_a0, log_s0) - log_a1 = _log_sub(log_a1, log_s1) - - i += 1 - if max(log_s0, log_s1) < -30: - break - - return _log_add(log_a0, log_a1) - - -def _log_erfc(x: float) -> float: - """Computes log(erfc(x)) with high accuracy for large x.""" - try: - return math.log(2) + special.log_ndtr(-x * 2**.5) - except NameError: - # If log_ndtr is not available, approximate as follows: - r = special.erfc(x) - if r == 0.0: - # Using the Laurent series at infinity for the tail of the erfc function: - # erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5) - # To verify in Mathematica: - # Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}] - return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 + - .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8) - else: - return math.log(r) - - -def _compute_delta(orders: Sequence[float], rdp: Sequence[float], - epsilon: float) -> float: - """Compute delta given a list of RDP values and target epsilon. - - Args: - orders: An array of orders. - rdp: An array of RDP guarantees. - epsilon: The target epsilon. - - Returns: - Optimal delta. - - Raises: - ValueError: If input is malformed. - - """ - if epsilon < 0: - raise ValueError(f'Epsilon cannot be negative. Found {epsilon}.') - if len(orders) != len(rdp): - raise ValueError('Input lists must have the same length.') - - # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): - # delta = min( np.exp((rdp - epsilon) * (orders - 1)) ) - - # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4): - logdeltas = [] # work in log space to avoid overflows - for (a, r) in zip(orders, rdp): - if a < 1: - raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.') - if r < 0: - raise ValueError(f'Renyi divergence cannot be negative. Found {r}.') - # For small alpha, we are better of with bound via KL divergence: - # delta <= sqrt(1-exp(-KL)). - # Take a min of the two bounds. - if r == 0: - logdelta = -np.inf - else: - logdelta = 0.5 * math.log1p(-math.exp(-r)) - if a > 1.01: - # This bound is not numerically stable as alpha->1. - # Thus we have a min value for alpha. - # The bound is also not useful for small alpha, so doesn't matter. - rdp_bound = (a - 1) * (r - epsilon + math.log1p(-1 / a)) - math.log(a) - logdelta = min(logdelta, rdp_bound) - - logdeltas.append(logdelta) - - return min(math.exp(np.min(logdeltas)), 1.) - - -def _compute_epsilon(orders: Sequence[float], rdp: Sequence[float], - delta: float) -> float: - """Compute epsilon given a list of RDP values and target delta. - - Args: - orders: An array of orders. - rdp: An array of RDP guarantees. - delta: The target delta. Must be >= 0. - - Returns: - Optimal epsilon. - - Raises: - ValueError: If input is malformed. - - """ - if delta < 0: - raise ValueError(f'Delta cannot be negative. Found {delta}.') - - if delta == 0: - if all(r == 0 for r in rdp): - return 0 - else: - return np.inf - - if len(orders) != len(rdp): - raise ValueError('Input lists must have the same length.') - - # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): - # epsilon = min( rdp - math.log(delta) / (orders - 1) ) - - # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4). - # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1). - eps = [] - for (a, r) in zip(orders, rdp): - if a < 1: - raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.') - if r < 0: - raise ValueError(f'Renyi divergence cannot be negative. Found {r}.') - - if delta**2 + math.expm1(-r) > 0: - # In this case, we can simply bound via KL divergence: - # delta <= sqrt(1-exp(-KL)). - epsilon = 0 # No need to try further computation if we have epsilon = 0. - elif a > 1.01: - # This bound is not numerically stable as alpha->1. - # Thus we have a min value of alpha. - # The bound is also not useful for small alpha, so doesn't matter. - epsilon = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1) - else: - # In this case we can't do anything. E.g., asking for delta = 0. - epsilon = np.inf - eps.append(epsilon) - - return max(0, np.min(eps)) - - -def _stable_inplace_diff_in_log(vec: np.ndarray, - signs: np.ndarray, - n: Optional[int] = None): - """Replaces the first n-1 dims of vec with the log of abs difference operator. - - Args: - vec: numpy array of floats with size larger than 'n' - signs: Optional numpy array of bools with the same size as vec in case one - needs to compute partial differences vec and signs jointly describe a - vector of real numbers' sign and abs in log scale. - n: Optonal upper bound on number of differences to compute. If None, all - differences are computed. - - Returns: - The first n-1 dimension of vec and signs will store the log-abs and sign of - the difference. - - Raises: - ValueError: If input is malformed. - """ - - if vec.shape != signs.shape: - raise ValueError('Shape of vec and signs do not match.') - if signs.dtype != bool: - raise ValueError('signs must be of type bool') - if n is None: - n = np.max(vec.shape) - 1 - else: - assert np.max(vec.shape) >= n + 1 - for j in range(0, n, 1): - if signs[j] == signs[j + 1]: # When the signs are the same - # if the signs are both positive, then we can just use the standard one - signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j]) - # otherwise, we do that but toggle the sign - if not signs[j + 1]: - signs[j] = ~signs[j] - else: # When the signs are different. - vec[j] = _log_add(vec[j], vec[j + 1]) - signs[j] = signs[j + 1] - - -def _get_forward_diffs(fun: Callable[[float], float], - n: int) -> Tuple[np.ndarray, np.ndarray]: - """Computes up to nth order forward difference evaluated at 0. - - See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf - - Args: - fun: Function to compute forward differences of. - n: Number of differences to compute. - - Returns: - Pair (deltas, signs_deltas) of the log deltas and their signs. - """ - func_vec = np.zeros(n + 3) - signs_func_vec = np.ones(n + 3, dtype=bool) - - # ith coordinate of deltas stores log(abs(ith order discrete derivative)) - deltas = np.zeros(n + 2) - signs_deltas = np.zeros(n + 2, dtype=bool) - for i in range(1, n + 3, 1): - func_vec[i] = fun(1.0 * (i - 1)) - for i in range(0, n + 2, 1): - # Diff in log scale - _stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i) - deltas[i] = func_vec[0] - signs_deltas[i] = signs_func_vec[0] - return deltas, signs_deltas - - -def _compute_log_a(q: float, noise_multiplier: float, - alpha: Union[int, float]) -> float: - if float(alpha).is_integer(): - return _compute_log_a_int(q, noise_multiplier, int(alpha)) - else: - return _compute_log_a_frac(q, noise_multiplier, alpha) - - -def _compute_rdp_poisson_subsampled_gaussian( - q: float, noise_multiplier: float, - orders: Sequence[float]) -> Union[float, np.ndarray]: - """Computes RDP of the Poisson sampled Gaussian mechanism. - - Args: - q: The sampling rate. - noise_multiplier: The ratio of the standard deviation of the Gaussian noise - to the l2-sensitivity of the function to which it is added. - orders: An array of RDP orders. - - Returns: - The RDPs at all orders. Can be `np.inf`. - """ - - def compute_one_order(q, alpha): - if np.isinf(alpha) or noise_multiplier == 0: - return np.inf - - if q == 0: - return 0 - - if q == 1.: - return alpha / (2 * noise_multiplier**2) - - return _compute_log_a(q, noise_multiplier, alpha) / (alpha - 1) - - return np.array([compute_one_order(q, order) for order in orders]) - - -def _compute_rdp_sample_wor_gaussian( - q: float, noise_multiplier: float, - orders: Sequence[float]) -> Union[float, np.ndarray]: - """Computes RDP of Gaussian mechanism using sampling without replacement. - - This function applies to the following schemes: - 1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n. - 2. ``Replace one data point'' version of differential privacy, i.e., n is - considered public information. - - Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (A strengthened - version applies subsampled-Gaussian mechanism.) - - Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and - Analytical Moments Accountant." AISTATS'2019. - - Args: - q: The sampling proportion = m / n. Assume m is an integer <= n. - noise_multiplier: The ratio of the standard deviation of the Gaussian noise - to the l2-sensitivity of the function to which it is added. - orders: An array of RDP orders. - - Returns: - The RDPs at all orders, can be np.inf. - """ - return np.array([ - _compute_rdp_sample_wor_gaussian_scalar(q, noise_multiplier, order) - for order in orders - ]) - - -def _compute_rdp_sample_wor_gaussian_scalar(q: float, sigma: float, - alpha: Union[float, int]) -> float: - """Compute RDP of the Sampled Gaussian mechanism at order alpha. - - Args: - q: The sampling proportion = m / n. Assume m is an integer <= n. - sigma: The std of the additive Gaussian noise. - alpha: The order at which RDP is computed. - - Returns: - RDP at alpha, can be np.inf. - """ - - assert (q <= 1) and (q >= 0) and (alpha >= 1) - - if q == 0: - return 0 - - if q == 1.: - return alpha / (2 * sigma**2) - - if np.isinf(alpha): - return np.inf - - if float(alpha).is_integer(): - return _compute_rdp_sample_wor_gaussian_int(q, sigma, int(alpha)) / ( - alpha - 1) - else: - # When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate - # the CGF and obtain an upper bound - alpha_f = math.floor(alpha) - alpha_c = math.ceil(alpha) - - x = _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha_f) - y = _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha_c) - t = alpha - alpha_f - return ((1 - t) * x + t * y) / (alpha - 1) - - -def _compute_rdp_sample_wor_gaussian_int(q: float, sigma: float, - alpha: int) -> float: - """Compute log(A_alpha) for integer alpha, subsampling without replacement. - - When alpha is smaller than max_alpha, compute the bound Theorem 27 exactly, - otherwise compute the bound with Stirling approximation. - - Args: - q: The sampling proportion = m / n. Assume m is an integer <= n. - sigma: The std of the additive Gaussian noise. - alpha: The order at which RDP is computed. - - Returns: - RDP at alpha, can be np.inf. - """ - - max_alpha = 256 - - if np.isinf(alpha): - return np.inf - elif alpha == 1: - return 0 - - def cgf(x): - # Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2) - return x * 1.0 * (x + 1) / (2.0 * sigma**2) - - def func(x): - # Return the rdp of Gaussian mechanism - return 1.0 * x / (2.0 * sigma**2) - - # Initialize with 1 in the log space. - log_a = 0 - # Calculates the log term when alpha = 2 - log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0))) - if alpha <= max_alpha: - # We need forward differences of exp(cgf) - # The following line is the numerically stable way of implementing it. - # The output is in polar form with logarithmic magnitude - deltas, _ = _get_forward_diffs(cgf, alpha) - # Compute the bound exactly requires book keeping of O(alpha**2) - - for i in range(2, alpha + 1): - if i == 2: - s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum( - np.log(4) + log_f2m1, - func(2.0) + np.log(2)) - elif i > 2: - delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1] - delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1] - s = np.log(4) + 0.5 * (delta_lo + delta_hi) - s = np.minimum(s, np.log(2) + cgf(i - 1)) - s += i * np.log(q) + _log_comb(alpha, i) - log_a = _log_add(log_a, s) - return float(log_a) - else: - # Compute the bound with stirling approximation. Everything is O(x) now. - for i in range(2, alpha + 1): - if i == 2: - s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum( - np.log(4) + log_f2m1, - func(2.0) + np.log(2)) - else: - s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i) - log_a = _log_add(log_a, s) - - return log_a - - -def _effective_gaussian_noise_multiplier( - event: dp_event.DpEvent) -> Optional[float]: - """Determines the effective noise multiplier of nested structure of Gaussians. - - A series of Gaussian queries on the same data can be reexpressed as a single - query with pre- and post- processing. For details, see section 3 of - https://arxiv.org/pdf/1812.06210.pdf. - - Args: - event: A `dp_event.DpEvent`. In order for conversion to be successful it - must consist of a single `dp_event.GaussianDpEvent`, or a nested structure - of `dp_event.ComposedDpEvent` and/or `dp_event.SelfComposedDpEvent` - bottoming out in `dp_event.GaussianDpEvent`s. - - Returns: - The noise multiplier of the equivalent `dp_event.GaussianDpEvent`, or None - if the input event was not a `dp_event.GaussianDpEvent` or a nested - structure of `dp_event.ComposedDpEvent` and/or - `dp_event.SelfComposedDpEvent` bottoming out in `dp_event.GaussianDpEvent`s. - """ - if isinstance(event, dp_event.GaussianDpEvent): - return event.noise_multiplier - elif isinstance(event, dp_event.ComposedDpEvent): - sum_sigma_inv_sq = 0 - for e in event.events: - sigma = _effective_gaussian_noise_multiplier(e) - if sigma is None: - return None - sum_sigma_inv_sq += sigma**-2 - return sum_sigma_inv_sq**-0.5 - elif isinstance(event, dp_event.SelfComposedDpEvent): - sigma = _effective_gaussian_noise_multiplier(event.event) - return None if sigma is None else (event.count * sigma**-2)**-0.5 - else: - return None - - -def _compute_rdp_single_epoch_tree_aggregation( - noise_multiplier: float, step_counts: Union[int, Sequence[int]], - orders: Sequence[float]) -> Union[float, np.ndarray]: - """Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism. - - This function implements the accounting when the tree is periodically - restarted and no record occurs twice across all trees. See appendix D of - "Practical and Private (Deep) Learning without Sampling or Shuffling" - https://arxiv.org/abs/2103.00039. - - Args: - noise_multiplier: A non-negative float representing the ratio of the - standard deviation of the Gaussian noise to the l2-sensitivity of the - function to which it is added. - step_counts: A scalar or a list of non-negative integers representing the - number of steps per epoch (between two restarts). - orders: An array of RDP orders. - - Returns: - The RDPs at all orders. Can be `np.inf`. - """ - if noise_multiplier < 0: - raise ValueError( - f'noise_multiplier must be non-negative. Got {noise_multiplier}.') - if noise_multiplier == 0: - return np.inf - - if not step_counts: - raise ValueError( - 'steps_list must be a non-empty list, or a non-zero scalar. Got ' - f'{step_counts}.') - - if np.isscalar(step_counts): - step_counts = [step_counts] - - for steps in step_counts: - if steps < 0: - raise ValueError(f'Steps must be non-negative. Got {step_counts}') - - max_depth = math.ceil(math.log2(max(step_counts) + 1)) - return np.array([a * max_depth / (2 * noise_multiplier**2) for a in orders]) - - -class RdpAccountant(privacy_accountant.PrivacyAccountant): - """Privacy accountant that uses Renyi differential privacy.""" - - def __init__( - self, - orders: Optional[Sequence[float]] = None, - neighboring_relation: NeighborRel = NeighborRel.ADD_OR_REMOVE_ONE, - ): - super().__init__(neighboring_relation) - if orders is None: - # Default orders chosen to give good coverage for Gaussian mechanism in - # the privacy regime of interest. In the future, more orders might be - # added, in particular, fractional orders between 1.0 and 10.0 or so. - orders = [ - 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 20, 24, 28, 32, 48, 64, 128, - 256, 512, 1024 - ] - self._orders = np.array(orders) - self._rdp = np.zeros_like(orders, dtype=np.float64) - - def supports(self, event: dp_event.DpEvent) -> bool: - return self._maybe_compose(event, 0, False) - - def _compose(self, event: dp_event.DpEvent, count: int = 1): - self._maybe_compose(event, count, True) - - def _maybe_compose(self, event: dp_event.DpEvent, count: int, - do_compose: bool) -> bool: - """Traverses `event` and performs composition if `do_compose` is True. - - If `do_compose` is False, can be used to check whether composition is - supported. - - Args: - event: A `DpEvent` to process. - count: The number of times to compose the event. - do_compose: Whether to actually perform the composition. - - Returns: - True if event is supported, otherwise False. - """ - - if isinstance(event, dp_event.NoOpDpEvent): - return True - elif isinstance(event, dp_event.NonPrivateDpEvent): - if do_compose: - self._rdp += np.inf - return True - elif isinstance(event, dp_event.SelfComposedDpEvent): - return self._maybe_compose(event.event, event.count * count, do_compose) - elif isinstance(event, dp_event.ComposedDpEvent): - return all( - self._maybe_compose(e, count, do_compose) for e in event.events) - elif isinstance(event, dp_event.GaussianDpEvent): - if do_compose: - self._rdp += count * _compute_rdp_poisson_subsampled_gaussian( - q=1.0, noise_multiplier=event.noise_multiplier, orders=self._orders) - return True - elif isinstance(event, dp_event.PoissonSampledDpEvent): - if self._neighboring_relation is not NeighborRel.ADD_OR_REMOVE_ONE: - return False - gaussian_noise_multiplier = _effective_gaussian_noise_multiplier( - event.event) - if gaussian_noise_multiplier is None: - return False - if do_compose: - self._rdp += count * _compute_rdp_poisson_subsampled_gaussian( - q=event.sampling_probability, - noise_multiplier=gaussian_noise_multiplier, - orders=self._orders) - return True - elif isinstance(event, dp_event.SampledWithoutReplacementDpEvent): - if self._neighboring_relation is not NeighborRel.REPLACE_ONE: - return False - gaussian_noise_multiplier = _effective_gaussian_noise_multiplier( - event.event) - if gaussian_noise_multiplier is None: - return False - if do_compose: - self._rdp += count * _compute_rdp_sample_wor_gaussian( - q=event.sample_size / event.source_dataset_size, - noise_multiplier=gaussian_noise_multiplier, - orders=self._orders) - return True - elif isinstance(event, dp_event.SingleEpochTreeAggregationDpEvent): - if self._neighboring_relation is not NeighborRel.REPLACE_SPECIAL: - return False - if do_compose: - self._rdp += count * _compute_rdp_single_epoch_tree_aggregation( - event.noise_multiplier, event.step_counts, self._orders) - return True - else: - # Unsupported event (including `UnsupportedDpEvent`). - return False - - def get_epsilon(self, target_delta: float) -> float: - return _compute_epsilon(self._orders, self._rdp, target_delta) - - def get_delta(self, target_epsilon: float) -> float: - return _compute_delta(self._orders, self._rdp, target_epsilon) diff --git a/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant_test.py deleted file mode 100644 index b169f71..0000000 --- a/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant_test.py +++ /dev/null @@ -1,465 +0,0 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for rdp_privacy_accountant.""" - -import math -import sys - -from absl.testing import absltest -from absl.testing import parameterized -import mpmath -import numpy as np - -from tensorflow_privacy.privacy.analysis import dp_event -from tensorflow_privacy.privacy.analysis import privacy_accountant -from tensorflow_privacy.privacy.analysis import privacy_accountant_test -from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant - - -def _get_test_rdp(event, count=1): - accountant = rdp_privacy_accountant.RdpAccountant(orders=[2.71828]) - accountant.compose(event, count) - return accountant._rdp[0] - - -def _log_float_mp(x): - # Convert multi-precision input to float log space. - if x >= sys.float_info.min: - return float(mpmath.log(x)) - else: - return -np.inf - - -def _compute_a_mp(sigma, q, alpha): - """Compute A_alpha for arbitrary alpha by numerical integration.""" - - def mu0(x): - return mpmath.npdf(x, mu=0, sigma=sigma) - - def _mu_over_mu0(x, q, sigma): - return (1 - q) + q * mpmath.exp((2 * x - 1) / (2 * sigma**2)) - - def a_alpha_fn(z): - return mu0(z) * _mu_over_mu0(z, q, sigma)**alpha - - bounds = (-mpmath.inf, mpmath.inf) - a_alpha, _ = mpmath.quad(a_alpha_fn, bounds, error=True, maxdegree=8) - return a_alpha - - -def _compose_trees(noise_multiplier, step_counts, orders): - accountant = rdp_privacy_accountant.RdpAccountant( - orders, privacy_accountant.NeighboringRelation.REPLACE_SPECIAL) - accountant.compose( - dp_event.ComposedDpEvent([ - dp_event.SingleEpochTreeAggregationDpEvent(noise_multiplier, - step_count) - for step_count in step_counts - ])) - return accountant - - -def _compose_trees_single_epoch(noise_multiplier, step_counts, orders): - accountant = rdp_privacy_accountant.RdpAccountant( - orders, privacy_accountant.NeighboringRelation.REPLACE_SPECIAL) - accountant.compose( - dp_event.SingleEpochTreeAggregationDpEvent(noise_multiplier, step_counts)) - return accountant - - -class RdpPrivacyAccountantTest(privacy_accountant_test.PrivacyAccountantTest, - parameterized.TestCase): - - def _make_test_accountants(self): - return [ - rdp_privacy_accountant.RdpAccountant( - [2.0], privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE), - rdp_privacy_accountant.RdpAccountant( - [2.0], privacy_accountant.NeighboringRelation.REPLACE_ONE), - rdp_privacy_accountant.RdpAccountant( - [2.0], privacy_accountant.NeighboringRelation.REPLACE_SPECIAL) - ] - - def test_supports(self): - aor_accountant = rdp_privacy_accountant.RdpAccountant( - [2.0], privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE) - ro_accountant = rdp_privacy_accountant.RdpAccountant( - [2.0], privacy_accountant.NeighboringRelation.REPLACE_ONE) - - event = dp_event.GaussianDpEvent(1.0) - self.assertTrue(aor_accountant.supports(event)) - self.assertTrue(ro_accountant.supports(event)) - - event = dp_event.SelfComposedDpEvent(dp_event.GaussianDpEvent(1.0), 6) - self.assertTrue(aor_accountant.supports(event)) - self.assertTrue(ro_accountant.supports(event)) - - event = dp_event.ComposedDpEvent( - [dp_event.GaussianDpEvent(1.0), - dp_event.GaussianDpEvent(2.0)]) - self.assertTrue(aor_accountant.supports(event)) - self.assertTrue(ro_accountant.supports(event)) - - event = dp_event.PoissonSampledDpEvent(0.1, dp_event.GaussianDpEvent(1.0)) - self.assertTrue(aor_accountant.supports(event)) - self.assertFalse(ro_accountant.supports(event)) - - composed_gaussian = dp_event.ComposedDpEvent( - [dp_event.GaussianDpEvent(1.0), - dp_event.GaussianDpEvent(2.0)]) - event = dp_event.PoissonSampledDpEvent(0.1, composed_gaussian) - self.assertTrue(aor_accountant.supports(event)) - self.assertFalse(ro_accountant.supports(event)) - - event = dp_event.SampledWithoutReplacementDpEvent( - 1000, 10, dp_event.GaussianDpEvent(1.0)) - self.assertFalse(aor_accountant.supports(event)) - self.assertTrue(ro_accountant.supports(event)) - - event = dp_event.SampledWithoutReplacementDpEvent(1000, 10, - composed_gaussian) - self.assertFalse(aor_accountant.supports(event)) - self.assertTrue(ro_accountant.supports(event)) - - event = dp_event.SampledWithReplacementDpEvent( - 1000, 10, dp_event.GaussianDpEvent(1.0)) - self.assertFalse(aor_accountant.supports(event)) - self.assertFalse(ro_accountant.supports(event)) - - def test_rdp_composition(self): - base_event = dp_event.GaussianDpEvent(3.14159) - base_rdp = _get_test_rdp(base_event) - - rdp_with_count = _get_test_rdp(base_event, count=6) - self.assertAlmostEqual(rdp_with_count, base_rdp * 6) - - rdp_with_self_compose = _get_test_rdp( - dp_event.SelfComposedDpEvent(base_event, 6)) - self.assertAlmostEqual(rdp_with_self_compose, base_rdp * 6) - - rdp_with_self_compose_and_count = _get_test_rdp( - dp_event.SelfComposedDpEvent(base_event, 2), count=3) - self.assertAlmostEqual(rdp_with_self_compose_and_count, base_rdp * 6) - - rdp_with_compose = _get_test_rdp(dp_event.ComposedDpEvent([base_event] * 6)) - self.assertAlmostEqual(rdp_with_compose, base_rdp * 6) - - rdp_with_compose_and_self_compose = _get_test_rdp( - dp_event.ComposedDpEvent([ - dp_event.SelfComposedDpEvent(base_event, 1), - dp_event.SelfComposedDpEvent(base_event, 2), - dp_event.SelfComposedDpEvent(base_event, 3) - ])) - self.assertAlmostEqual(rdp_with_compose_and_self_compose, base_rdp * 6) - - base_event_2 = dp_event.GaussianDpEvent(1.61803) - base_rdp_2 = _get_test_rdp(base_event_2) - rdp_with_heterogeneous_compose = _get_test_rdp( - dp_event.ComposedDpEvent([base_event, base_event_2])) - self.assertAlmostEqual(rdp_with_heterogeneous_compose, - base_rdp + base_rdp_2) - - def test_zero_poisson_sample(self): - accountant = rdp_privacy_accountant.RdpAccountant([3.14159]) - accountant.compose( - dp_event.PoissonSampledDpEvent(0, dp_event.GaussianDpEvent(1.0))) - self.assertEqual(accountant.get_epsilon(1e-10), 0) - self.assertEqual(accountant.get_delta(1e-10), 0) - - def test_zero_fixed_batch_sample(self): - accountant = rdp_privacy_accountant.RdpAccountant( - [3.14159], privacy_accountant.NeighboringRelation.REPLACE_ONE) - accountant.compose( - dp_event.SampledWithoutReplacementDpEvent( - 1000, 0, dp_event.GaussianDpEvent(1.0))) - self.assertEqual(accountant.get_epsilon(1e-10), 0) - self.assertEqual(accountant.get_delta(1e-10), 0) - - def test_epsilon_non_private_gaussian(self): - accountant = rdp_privacy_accountant.RdpAccountant([3.14159]) - accountant.compose(dp_event.GaussianDpEvent(0)) - self.assertEqual(accountant.get_epsilon(1e-1), np.inf) - - def test_compute_rdp_gaussian(self): - alpha = 3.14159 - sigma = 2.71828 - event = dp_event.GaussianDpEvent(sigma) - accountant = rdp_privacy_accountant.RdpAccountant(orders=[alpha]) - accountant.compose(event) - self.assertAlmostEqual(accountant._rdp[0], alpha / (2 * sigma**2)) - - def test_compute_rdp_multi_gaussian(self): - alpha = 3.14159 - sigma1, sigma2 = 2.71828, 6.28319 - - rdp1 = alpha / (2 * sigma1**2) - rdp2 = alpha / (2 * sigma2**2) - rdp = rdp1 + rdp2 - - accountant = rdp_privacy_accountant.RdpAccountant(orders=[alpha]) - accountant.compose( - dp_event.PoissonSampledDpEvent( - 1.0, - dp_event.ComposedDpEvent([ - dp_event.GaussianDpEvent(sigma1), - dp_event.GaussianDpEvent(sigma2) - ]))) - self.assertAlmostEqual(accountant._rdp[0], rdp) - - def test_effective_gaussian_noise_multiplier(self): - np.random.seed(0xBAD5EED) - sigmas = np.random.uniform(size=(4,)) - - event = dp_event.ComposedDpEvent([ - dp_event.GaussianDpEvent(sigmas[0]), - dp_event.SelfComposedDpEvent(dp_event.GaussianDpEvent(sigmas[1]), 3), - dp_event.ComposedDpEvent([ - dp_event.GaussianDpEvent(sigmas[2]), - dp_event.GaussianDpEvent(sigmas[3]) - ]) - ]) - - sigma = rdp_privacy_accountant._effective_gaussian_noise_multiplier(event) - multi_sigmas = list(sigmas) + [sigmas[1]] * 2 - expected = sum(s**-2 for s in multi_sigmas)**-0.5 - self.assertAlmostEqual(sigma, expected) - - def test_compute_rdp_poisson_sampled_gaussian(self): - orders = [1.5, 2.5, 5, 50, 100, np.inf] - noise_multiplier = 2.5 - sampling_probability = 0.01 - count = 50 - event = dp_event.SelfComposedDpEvent( - dp_event.PoissonSampledDpEvent( - sampling_probability, dp_event.GaussianDpEvent(noise_multiplier)), - count) - accountant = rdp_privacy_accountant.RdpAccountant(orders=orders) - accountant.compose(event) - self.assertTrue( - np.allclose( - accountant._rdp, [ - 6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02, - np.inf - ], - rtol=1e-4)) - - def test_compute_epsilon_delta_pure_dp(self): - orders = range(2, 33) - rdp = [1.1 for o in orders] # Constant corresponds to pure DP. - - epsilon = rdp_privacy_accountant._compute_epsilon(orders, rdp, delta=1e-5) - # Compare with epsilon computed by hand. - self.assertAlmostEqual(epsilon, 1.32783806176) - - delta = rdp_privacy_accountant._compute_delta( - orders, rdp, epsilon=1.32783806176) - self.assertAlmostEqual(delta, 1e-5) - - def test_compute_epsilon_delta_gaussian(self): - orders = [0.001 * i for i in range(1000, 100000)] - - # noise multiplier is chosen to obtain exactly (1,1e-6)-DP. - rdp = rdp_privacy_accountant._compute_rdp_poisson_subsampled_gaussian( - 1, 4.530877117, orders) - - eps = rdp_privacy_accountant._compute_epsilon(orders, rdp, delta=1e-6) - self.assertAlmostEqual(eps, 1) - - delta = rdp_privacy_accountant._compute_delta(orders, rdp, epsilon=1) - self.assertAlmostEqual(delta, 1e-6) - - params = ({ - 'q': 1e-7, - 'sigma': .1, - 'order': 1.01 - }, { - 'q': 1e-6, - 'sigma': .1, - 'order': 256 - }, { - 'q': 1e-5, - 'sigma': .1, - 'order': 256.1 - }, { - 'q': 1e-6, - 'sigma': 1, - 'order': 27 - }, { - 'q': 1e-4, - 'sigma': 1., - 'order': 1.5 - }, { - 'q': 1e-3, - 'sigma': 1., - 'order': 2 - }, { - 'q': .01, - 'sigma': 10, - 'order': 20 - }, { - 'q': .1, - 'sigma': 100, - 'order': 20.5 - }, { - 'q': .99, - 'sigma': .1, - 'order': 256 - }, { - 'q': .999, - 'sigma': 100, - 'order': 256.1 - }) - - # pylint:disable=undefined-variable - @parameterized.parameters(p for p in params) - def test_compute_log_a_equals_mp(self, q, sigma, order): - # Compare the cheap computation of log(A) with an expensive, multi-precision - # computation. - log_a = rdp_privacy_accountant._compute_log_a(q, sigma, order) - log_a_mp = _log_float_mp(_compute_a_mp(sigma, q, order)) - np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4) - - def test_delta_bounds_gaussian(self): - # Compare the optimal bound for Gaussian with the one derived from RDP. - # Also compare the RDP upper bound with the "standard" upper bound. - orders = [0.1 * x for x in range(10, 505)] - eps_vec = [0.1 * x for x in range(500)] - rdp = rdp_privacy_accountant._compute_rdp_poisson_subsampled_gaussian( - 1, 1, orders) - for eps in eps_vec: - delta = rdp_privacy_accountant._compute_delta(orders, rdp, epsilon=eps) - # For comparison, we compute the optimal guarantee for Gaussian - # using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2). - delta0 = math.erfc((eps - .5) / math.sqrt(2)) / 2 - delta0 = delta0 - math.exp(eps) * math.erfc((eps + .5) / math.sqrt(2)) / 2 - self.assertLessEqual(delta0, delta + 1e-300) # need tolerance 10^-300 - - # Compute the "standard" upper bound, which should be an upper bound. - # Note, if orders is too sparse, this will NOT be an upper bound. - if eps >= 0.5: - delta1 = math.exp(-0.5 * (eps - 0.5)**2) - else: - delta1 = 1 - self.assertLessEqual(delta, delta1 + 1e-300) - - def test_epsilon_delta_consistency(self): - orders = range(2, 50) # Large range of orders (helps test for overflows). - for q in [0, 0.01, 0.1, 0.8, 1.]: - for multiplier in [0.0, 0.1, 1., 10., 100.]: - event = dp_event.PoissonSampledDpEvent( - q, dp_event.GaussianDpEvent(multiplier)) - accountant = rdp_privacy_accountant.RdpAccountant(orders) - accountant.compose(event) - for delta in [.99, .9, .1, .01, 1e-3, 1e-5, 1e-9, 1e-12]: - epsilon = accountant.get_epsilon(delta) - delta2 = accountant.get_delta(epsilon) - if np.isposinf(epsilon): - self.assertEqual(delta2, 1.0) - elif epsilon == 0: - self.assertLessEqual(delta2, delta) - else: - self.assertAlmostEqual(delta, delta2) - - @parameterized.named_parameters( - ('add_remove', privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE), - ('replace', privacy_accountant.NeighboringRelation.REPLACE_ONE)) - def test_tree_wrong_neighbor_rel(self, neighboring_relation): - event = dp_event.SingleEpochTreeAggregationDpEvent(1.0, 1) - accountant = rdp_privacy_accountant.RdpAccountant( - neighboring_relation=neighboring_relation) - self.assertFalse(accountant.supports(event)) - - @parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04)) - def test_compute_eps_tree(self, noise_multiplier, eps): - orders = [1 + x / 10 for x in range(1, 100)] + list(range(12, 64)) - # This test is based on the StackOverflow setting in "Practical and - # Private (Deep) Learning without Sampling or Shuffling". The calculated - # epsilon could be better as the method in this package keeps improving. - step_counts, target_delta = 1600, 1e-6 - new_eps = _compose_trees_single_epoch(noise_multiplier, step_counts, - orders).get_epsilon(target_delta) - self.assertLess(new_eps, eps) - - @parameterized.named_parameters( - ('restart4', [400] * 4), - ('restart2', [800] * 2), - ('adaptive', [10, 400, 400, 400, 390]), - ) - def test_compose_tree_rdp(self, step_counts): - noise_multiplier, orders = 0.1, [1] - - def get_rdp(step_count): - return _compose_trees_single_epoch(noise_multiplier, [step_count], - orders)._rdp[0] - - rdp_summed = sum(get_rdp(step_count) for step_count in step_counts) - rdp_composed = _compose_trees(noise_multiplier, step_counts, orders)._rdp[0] - self.assertTrue(np.allclose(rdp_composed, rdp_summed, rtol=1e-12)) - - def test_single_epoch_multi_tree_rdp(self): - noise_multiplier, orders = 0.1, [1] - step_counts = [10, 40, 30, 20] - single_rdp = _compose_trees_single_epoch(noise_multiplier, step_counts, - orders)._rdp[0] - - max_rdp = max( - _compose_trees_single_epoch(noise_multiplier, step_count, - orders)._rdp[0] - for step_count in step_counts) - - self.assertEqual(single_rdp, max_rdp) - - @parameterized.named_parameters( - ('restart4', [400] * 4), - ('restart2', [800] * 2), - ('adaptive', [10, 400, 400, 400, 390]), - ) - def test_compute_eps_tree_decreasing(self, step_counts): - # Test privacy epsilon decreases with noise multiplier increasing when - # keeping other parameters the same. - orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) - target_delta = 1e-6 - prev_eps = np.inf - for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]: - accountant = _compose_trees(noise_multiplier, step_counts, orders) - eps = accountant.get_epsilon(target_delta=target_delta) - self.assertLess(eps, prev_eps) - prev_eps = eps - - @parameterized.named_parameters( - ('negative_noise', -1, [3]), - ('negative_steps', 1, [-3]), - ) - def test_compute_rdp_tree_restart_raise(self, noise_multiplier, step_counts): - with self.assertRaisesRegex(ValueError, 'non-negative'): - _compose_trees(noise_multiplier, step_counts, orders=[1]) - - @parameterized.named_parameters( - ('t100n0.1', 100, 0.1), - ('t1000n0.01', 1000, 0.01), - ) - def test_no_tree_no_sampling(self, total_steps, noise_multiplier): - orders = [1 + x / 10 for x in range(1, 100)] + list(range(12, 64)) - tree_rdp = _compose_trees(noise_multiplier, [1] * total_steps, orders)._rdp - accountant = rdp_privacy_accountant.RdpAccountant(orders) - event = dp_event.SelfComposedDpEvent( - dp_event.GaussianDpEvent(noise_multiplier), total_steps) - accountant.compose(event) - base_rdp = accountant._rdp - self.assertTrue(np.allclose(tree_rdp, base_rdp, rtol=1e-12)) - - -if __name__ == '__main__': - absltest.main()