Migrate dependency of tree_aggregation_accountant_test on rdp_accountant to differential_privacy.

PiperOrigin-RevId: 453989532
This commit is contained in:
Galen Andrew 2022-06-09 12:43:40 -07:00 committed by A. Unique TensorFlower
parent 6c0cc858e0
commit 125f82707a
2 changed files with 11 additions and 6 deletions

View file

@ -121,5 +121,7 @@ py_test(
deps = [
":rdp_accountant",
":tree_aggregation_accountant",
"@com_google_differential_py//python/dp_accounting:dp_event",
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
],
)

View file

@ -16,9 +16,11 @@
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import rdp_accountant
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant
from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
@ -31,8 +33,8 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
steps_list, target_delta = 1600, 1e-6
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
new_eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
new_eps = rdp_privacy_accountant.compute_epsilon(orders, rdp,
target_delta)[0]
self.assertLess(new_eps, eps)
@parameterized.named_parameters(
@ -65,8 +67,7 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
eps = rdp_privacy_accountant.compute_epsilon(orders, rdp, target_delta)[0]
self.assertLess(eps, prev_eps)
prev_eps = eps
@ -89,7 +90,9 @@ class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
tree_rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, [1] * total_steps, orders)
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
accountant = rdp_privacy_accountant.RdpAccountant(orders)
accountant.compose(dp_event.GaussianDpEvent(noise_multiplier), total_steps)
rdp = accountant._rdp # pylint: disable=protected-access
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
@parameterized.named_parameters(