Add a metric for TA-DP-FTRL,

PiperOrigin-RevId: 590791663
This commit is contained in:
Zheng Xu 2023-12-13 20:02:53 -08:00 committed by A. Unique TensorFlower
parent 81a4fd82f7
commit a4deb12ee0
2 changed files with 17 additions and 2 deletions

View file

@ -33,11 +33,10 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0] eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
""" """
import collections
from typing import Any, NamedTuple from typing import Any, NamedTuple
import dp_accounting import dp_accounting
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation from tensorflow_privacy.privacy.dp_query import tree_aggregation
@ -476,6 +475,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
previous_tree_noise=global_state.previous_tree_noise, previous_tree_noise=global_state.previous_tree_noise,
) )
def derive_metrics(self, global_state):
"""Returns the clip norm as a metric."""
return collections.OrderedDict(tree_agg_dpftrl_clip=global_state.clip_value)
@classmethod @classmethod
def build_l2_gaussian_query(cls, def build_l2_gaussian_query(cls,
clip_norm, clip_norm,

View file

@ -431,6 +431,18 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
) )
self.assertIsInstance(query._tree_aggregator, tree_class) self.assertIsInstance(query._tree_aggregator, tree_class)
def test_derive_metrics(self):
specs = tf.TensorSpec([])
l2_clip = 2
query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=l2_clip,
noise_generator=_get_noise_fn(specs, 1.0),
record_specs=specs,
)
metrics = query.derive_metrics(query.initial_global_state())
self.assertEqual(metrics['tree_agg_dpftrl_clip'], l2_clip)
@parameterized.named_parameters( @parameterized.named_parameters(
('s0t1f1', 0., 1., 1), ('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2), ('s0t1f2', 0., 1., 2),