Automated rollback of commit b16a0abf1c

PiperOrigin-RevId: 458478847
This commit is contained in:
A. Unique TensorFlower 2022-07-01 08:52:06 -07:00
parent b16a0abf1c
commit e32766cc73
24 changed files with 158 additions and 109 deletions

View file

@ -13,7 +13,10 @@ py_library(
name = "compute_dp_sgd_privacy_lib",
srcs = ["compute_dp_sgd_privacy_lib.py"],
srcs_version = "PY3",
deps = ["@com_google_differential_py//python/dp_accounting"],
deps = [
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
py_binary(
@ -42,7 +45,10 @@ py_binary(
py_library(
name = "compute_noise_from_budget_lib",
srcs = ["compute_noise_from_budget_lib.py"],
deps = ["@com_google_differential_py//python/dp_accounting"],
deps = [
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
py_test(
@ -61,7 +67,11 @@ py_library(
srcs = ["rdp_accountant.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = ["@com_google_differential_py//python/dp_accounting"],
deps = [
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting:privacy_accountant",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
py_test(
@ -109,7 +119,9 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":rdp_accountant",
":tree_aggregation_accountant",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)

View file

@ -17,18 +17,18 @@
import math
from absl import app
from com_google_differential_py.python.dp_accounting
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
"""Compute and print results of DP-SGD analysis."""
accountant = dp_accounting.rdp.RdpAccountant(orders)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(q,
dp_accounting.GaussianDpEvent(sigma)),
steps)
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(q, dp_event.GaussianDpEvent(sigma)), steps)
accountant.compose(event)

View file

@ -17,18 +17,18 @@
import math
from absl import app
from com_google_differential_py.python.dp_accounting
from scipy import optimize
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
"""Compute and print results of DP-SGD analysis."""
accountant = dp_accounting.rdp.RdpAccountant(orders)
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(q,
dp_accounting.GaussianDpEvent(sigma)),
steps)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(q, dp_event.GaussianDpEvent(sigma)), steps)
accountant.compose(event)
return accountant.get_epsilon_and_optimal_order(delta)

View file

@ -41,9 +41,12 @@ The example code would be:
eps, _, opt_order = rdp_accountant.get_privacy_spent(rdp, target_delta=delta)
"""
from com_google_differential_py.python.dp_accounting
import numpy as np
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting import privacy_accountant
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
def _compute_rdp_from_event(orders, event, count):
"""Computes RDP from a DpEvent using RdpAccountant.
@ -58,14 +61,15 @@ def _compute_rdp_from_event(orders, event, count):
"""
orders_vec = np.atleast_1d(orders)
if isinstance(event, dp_accounting.SampledWithoutReplacementDpEvent):
neighboring_relation = dp_accounting.NeighboringRelation.REPLACE_ONE
elif isinstance(event, dp_accounting.SingleEpochTreeAggregationDpEvent):
neighboring_relation = dp_accounting.NeighboringRelation.REPLACE_SPECIAL
if isinstance(event, dp_event.SampledWithoutReplacementDpEvent):
neighboring_relation = privacy_accountant.NeighboringRelation.REPLACE_ONE
elif isinstance(event, dp_event.SingleEpochTreeAggregationDpEvent):
neighboring_relation = privacy_accountant.NeighboringRelation.REPLACE_SPECIAL
else:
neighboring_relation = dp_accounting.NeighboringRelation.ADD_OR_REMOVE_ONE
neighboring_relation = privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE
accountant = dp_accounting.rdp.RdpAccountant(orders_vec, neighboring_relation)
accountant = rdp_privacy_accountant.RdpAccountant(orders_vec,
neighboring_relation)
accountant.compose(event, count)
rdp = accountant._rdp # pylint: disable=protected-access
@ -92,8 +96,8 @@ def compute_rdp(q, noise_multiplier, steps, orders):
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
event = dp_accounting.PoissonSampledDpEvent(
q, dp_accounting.GaussianDpEvent(noise_multiplier))
event = dp_event.PoissonSampledDpEvent(
q, dp_event.GaussianDpEvent(noise_multiplier))
return _compute_rdp_from_event(orders, event, steps)
@ -125,8 +129,8 @@ def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
Returns:
The RDPs at all orders, can be np.inf.
"""
event = dp_accounting.SampledWithoutReplacementDpEvent(
1, q, dp_accounting.GaussianDpEvent(noise_multiplier))
event = dp_event.SampledWithoutReplacementDpEvent(
1, q, dp_event.GaussianDpEvent(noise_multiplier))
return _compute_rdp_from_event(orders, event, steps)
@ -191,7 +195,7 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
raise ValueError(
"Exactly one out of eps and delta must be None. (None is).")
accountant = dp_accounting.rdp.RdpAccountant(orders)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
accountant._rdp = rdp # pylint: disable=protected-access
if target_eps is not None:

View file

@ -14,10 +14,13 @@
# ==============================================================================
from absl.testing import parameterized
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
@ -30,7 +33,8 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
steps_list, target_delta = 1600, 1e-6
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
new_eps = dp_accounting.rdp.compute_epsilon(orders, rdp, target_delta)[0]
new_eps = rdp_privacy_accountant.compute_epsilon(orders, rdp,
target_delta)[0]
self.assertLess(new_eps, eps)
@parameterized.named_parameters(
@ -63,7 +67,7 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
eps = dp_accounting.rdp.compute_epsilon(orders, rdp, target_delta)[0]
eps = rdp_privacy_accountant.compute_epsilon(orders, rdp, target_delta)[0]
self.assertLess(eps, prev_eps)
prev_eps = eps
@ -86,9 +90,8 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
tree_rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, [1] * total_steps, orders)
accountant = dp_accounting.rdp.RdpAccountant(orders)
accountant.compose(
dp_accounting.GaussianDpEvent(noise_multiplier), total_steps)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
accountant.compose(dp_event.GaussianDpEvent(noise_multiplier), total_steps)
rdp = accountant._rdp # pylint: disable=protected-access
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)

View file

@ -36,7 +36,7 @@ py_library(
deps = [
":discrete_gaussian_utils",
":dp_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -59,7 +59,7 @@ py_library(
deps = [
":discrete_gaussian_utils",
":dp_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -82,7 +82,7 @@ py_library(
deps = [
":dp_query",
":normalized_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -103,7 +103,7 @@ py_library(
srcs_version = "PY3",
deps = [
":dp_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -125,7 +125,7 @@ py_library(
srcs_version = "PY3",
deps = [
":dp_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -167,7 +167,7 @@ py_library(
srcs_version = "PY3",
deps = [
":dp_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -194,7 +194,7 @@ py_library(
":dp_query",
":gaussian_query",
":quantile_estimator_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -274,7 +274,7 @@ py_library(
deps = [
":dp_query",
":tree_aggregation",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)
@ -286,7 +286,7 @@ py_library(
":distributed_discrete_gaussian_query",
":dp_query",
":gaussian_query",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
],
)

View file

@ -15,11 +15,12 @@
import collections
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query
from com_google_differential_py.python.dp_accounting import dp_event
class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery for discrete Gaussian sum queries.
@ -83,5 +84,5 @@ class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
return tf.ensure_shape(noised_v, v.shape)
result = tf.nest.map_structure(add_noise, sample_state)
event = dp_accounting.UnsupportedDpEvent()
event = dp_event.UnsupportedDpEvent()
return result, global_state, event

View file

@ -15,11 +15,12 @@
import collections
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query
from com_google_differential_py.python.dp_accounting import dp_event
class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery for discrete distributed Gaussian sum queries.
@ -107,5 +108,5 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
# Note that by directly returning the aggregate, this assumes that there
# will not be missing local noise shares during execution.
event = dp_accounting.UnsupportedDpEvent()
event = dp_event.UnsupportedDpEvent()
return sample_state, global_state, event

View file

@ -15,11 +15,12 @@
import collections
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import normalized_query
from com_google_differential_py.python.dp_accounting import dp_event
class DistributedSkellamSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for discrete distributed sum queries.
@ -126,7 +127,7 @@ class DistributedSkellamSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""The noise was already added locally, therefore just continue."""
event = dp_accounting.UnsupportedDpEvent()
event = dp_event.UnsupportedDpEvent()
return sample_state, global_state, event

View file

@ -16,10 +16,11 @@
import collections
import distutils
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from com_google_differential_py.python.dp_accounting import dp_event
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for Gaussian sum queries.
@ -93,6 +94,6 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
result = tf.nest.map_structure(add_noise, sample_state)
noise_multiplier = global_state.stddev / global_state.l2_norm_clip
event = dp_accounting.GaussianDpEvent(noise_multiplier)
event = dp_event.GaussianDpEvent(noise_multiplier)
return result, global_state, event

View file

@ -15,11 +15,12 @@
import collections
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
import tree
from com_google_differential_py.python.dp_accounting import dp_event
class NestedQuery(dp_query.DPQuery):
"""Implements DPQuery interface for structured queries.
@ -101,7 +102,7 @@ class NestedQuery(dp_query.DPQuery):
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
tf.nest.pack_sequence_as(self._queries, flat_new_global_states),
dp_accounting.ComposedDpEvent(events=flat_events))
dp_event.ComposedDpEvent(events=flat_events))
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -13,10 +13,11 @@
# limitations under the License.
"""Implements DPQuery interface for no privacy average queries."""
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from com_google_differential_py.python.dp_accounting import dp_event
class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for a sum query with no privacy.
@ -26,7 +27,7 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
return sample_state, global_state, dp_accounting.NonPrivateDpEvent()
return sample_state, global_state, dp_event.NonPrivateDpEvent()
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
@ -85,4 +86,4 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
sum_state, denominator = sample_state
result = tf.nest.map_structure(lambda t: t / denominator, sum_state)
return result, global_state, dp_accounting.NonPrivateDpEvent()
return result, global_state, dp_event.NonPrivateDpEvent()

View file

@ -15,12 +15,13 @@
import collections
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
from com_google_differential_py.python.dp_accounting import dp_event
class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
"""`DPQuery` for Gaussian sum queries with adaptive clipping.
@ -137,7 +138,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
new_sum_query_state,
new_quantile_estimator_state)
event = dp_accounting.ComposedDpEvent(events=[sum_event, quantile_event])
event = dp_event.ComposedDpEvent(events=[sum_event, quantile_event])
return noised_vectors, new_global_state, event
def derive_metrics(self, global_state):

View file

@ -34,11 +34,11 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve
"""
import attr
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
from com_google_differential_py.python.dp_accounting import dp_event
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
@ -186,7 +186,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
global_state,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state)
event = dp_accounting.UnsupportedDpEvent()
event = dp_event.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event
def reset_state(self, noised_results, global_state):
@ -428,7 +428,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
global_state.previous_tree_noise)
new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
event = dp_accounting.UnsupportedDpEvent()
event = dp_event.UnsupportedDpEvent()
return noised_sample, new_global_state, event
def reset_state(self, noised_results, global_state):

View file

@ -21,12 +21,13 @@ import math
from typing import Optional
import attr
from com_google_differential_py.python.dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from com_google_differential_py.python.dp_accounting import dp_event
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes.
@ -203,7 +204,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
event = dp_accounting.UnsupportedDpEvent()
event = dp_event.UnsupportedDpEvent()
return tree, new_global_state, event
@classmethod

View file

@ -12,7 +12,9 @@ py_library(
":datasets",
":single_layer_softmax",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting:mechanism_calibration",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)

View file

@ -28,13 +28,16 @@ the algorithm of Abadi et al.: https://arxiv.org/pdf/1607.00133.pdf%20.
import math
from typing import List, Optional, Tuple
from com_google_differential_py.python.dp_accounting
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.logistic_regression import datasets
from tensorflow_privacy.privacy.logistic_regression import single_layer_softmax
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting import mechanism_calibration
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
@tf.keras.utils.register_keras_serializable(package='Custom', name='Kifer')
class KiferRegularizer(tf.keras.regularizers.Regularizer):
@ -173,17 +176,17 @@ def compute_dpsgd_noise_multiplier(num_train: int,
steps = int(math.ceil(epochs * num_train / batch_size))
def make_event_from_param(noise_multiplier):
return dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(
return dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
sampling_probability=batch_size / num_train,
event=dp_accounting.GaussianDpEvent(noise_multiplier)), steps)
event=dp_event.GaussianDpEvent(noise_multiplier)), steps)
return dp_accounting.calibrate_dp_mechanism(
lambda: dp_accounting.rdp.RdpAccountant(orders),
return mechanism_calibration.calibrate_dp_mechanism(
lambda: rdp_privacy_accountant.RdpAccountant(orders),
make_event_from_param,
epsilon,
delta,
dp_accounting.LowerEndpointAndGuess(0, 1),
mechanism_calibration.LowerEndpointAndGuess(0, 1),
tol=tolerance)

View file

@ -27,7 +27,8 @@ py_binary(
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
@ -38,7 +39,8 @@ py_binary(
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
@ -49,7 +51,8 @@ py_binary(
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/keras_models:dp_keras_model",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
@ -60,7 +63,8 @@ py_binary(
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_vectorized",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
@ -83,7 +87,8 @@ py_binary(
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)
@ -94,7 +99,8 @@ py_binary(
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
"@com_google_differential_py//python/dp_accounting",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)

View file

@ -35,13 +35,15 @@ import os
from absl import app
from absl import flags
from absl import logging
from com_google_differential_py.python.dp_accounting
import numpy as np
import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow.compat.v1 import estimator as tf_compat_v1_estimator
import tensorflow_datasets as tfds
from tensorflow_privacy.privacy.optimizers import dp_optimizer
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
flags.DEFINE_boolean(
@ -151,11 +153,11 @@ def compute_epsilon(steps):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / NB_TRAIN
accountant = dp_accounting.rdp.RdpAccountant(orders)
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(
accountant = rdp_privacy_accountant.RdpAccountant(orders)
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
sampling_probability,
dp_accounting.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
accountant.compose(event)
# Delta is set to 1e-5 because Penn TreeBank has 60000 training points.

View file

@ -15,11 +15,14 @@
from absl import app
from absl import flags
from com_google_differential_py.python.dp_accounting
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
tf.compat.v1.enable_eager_execution()
@ -44,13 +47,13 @@ def compute_epsilon(steps):
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
accountant = dp_accounting.rdp.RdpAccountant(orders)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
sampling_probability = FLAGS.batch_size / 60000
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
sampling_probability,
dp_accounting.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
accountant.compose(event)

View file

@ -16,11 +16,14 @@
from absl import app
from absl import flags
from absl import logging
from com_google_differential_py.python.dp_accounting
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, '
@ -44,13 +47,13 @@ def compute_epsilon(steps):
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
accountant = dp_accounting.rdp.RdpAccountant(orders)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
sampling_probability = FLAGS.batch_size / 60000
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
sampling_probability,
dp_accounting.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
accountant.compose(event)

View file

@ -16,11 +16,13 @@
from absl import app
from absl import flags
from absl import logging
from com_google_differential_py.python.dp_accounting
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.keras_models.dp_keras_model import DPSequential
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, '
@ -44,13 +46,13 @@ def compute_epsilon(steps):
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
accountant = dp_accounting.rdp.RdpAccountant(orders)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
sampling_probability = FLAGS.batch_size / 60000
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
sampling_probability,
dp_accounting.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
accountant.compose(event)

View file

@ -16,13 +16,14 @@
from absl import app
from absl import flags
from absl import logging
from com_google_differential_py.python.dp_accounting
import numpy as np
import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow.compat.v1 import estimator as tf_compat_v1_estimator
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, '
@ -50,13 +51,13 @@ def compute_epsilon(steps):
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
accountant = dp_accounting.rdp.RdpAccountant(orders)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
sampling_probability = FLAGS.batch_size / 60000
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
sampling_probability,
dp_accounting.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
accountant.compose(event)

View file

@ -26,12 +26,13 @@ import math
from absl import app
from absl import flags
from absl import logging
from com_google_differential_py.python.dp_accounting
import numpy as np
import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow.compat.v1 import estimator as tf_compat_v1_estimator
from tensorflow_privacy.privacy.optimizers import dp_optimizer
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
@ -165,14 +166,13 @@ def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
# Using RDP accountant to compute eps. Doing computation analytically is
# an option.
rdp = [order * coef for order in orders]
eps = dp_accounting.rdp.compute_epsilon(orders, rdp, delta)
eps = rdp_privacy_accountant.compute_epsilon(orders, rdp, delta)
print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(p * 100, eps, delta))
accountant = dp_accounting.rdp.RdpAccountant(orders)
event = dp_accounting.SelfComposedDpEvent(
dp_accounting.PoissonSampledDpEvent(
batch_size / samples,
dp_accounting.GaussianDpEvent(noise_multiplier)),
accountant = rdp_privacy_accountant.RdpAccountant(orders)
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
batch_size / samples, dp_event.GaussianDpEvent(noise_multiplier)),
epochs * steps_per_epoch)
accountant.compose(event)
eps_sgm = accountant.get_epsilon(target_delta=delta)