forked from 626_privacy/tensorflow_privacy
Add a metric for TA-DP-FTRL,
PiperOrigin-RevId: 590791663
This commit is contained in:
parent
81a4fd82f7
commit
a4deb12ee0
2 changed files with 17 additions and 2 deletions
|
@ -33,11 +33,10 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve
|
|||
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
|
||||
"""
|
||||
|
||||
import collections
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import dp_accounting
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||
|
||||
|
@ -476,6 +475,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
|
|||
previous_tree_noise=global_state.previous_tree_noise,
|
||||
)
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Returns the clip norm as a metric."""
|
||||
return collections.OrderedDict(tree_agg_dpftrl_clip=global_state.clip_value)
|
||||
|
||||
@classmethod
|
||||
def build_l2_gaussian_query(cls,
|
||||
clip_norm,
|
||||
|
|
|
@ -431,6 +431,18 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
self.assertIsInstance(query._tree_aggregator, tree_class)
|
||||
|
||||
def test_derive_metrics(self):
|
||||
specs = tf.TensorSpec([])
|
||||
l2_clip = 2
|
||||
query = tree_aggregation_query.TreeResidualSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=l2_clip,
|
||||
noise_generator=_get_noise_fn(specs, 1.0),
|
||||
record_specs=specs,
|
||||
)
|
||||
metrics = query.derive_metrics(query.initial_global_state())
|
||||
self.assertEqual(metrics['tree_agg_dpftrl_clip'], l2_clip)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('s0t1f1', 0., 1., 1),
|
||||
('s0t1f2', 0., 1., 2),
|
||||
|
|
Loading…
Reference in a new issue