diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 65bec4d..03bfdb1 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -33,11 +33,10 @@ corresponding epsilon for a `target_delta` and `noise_multiplier` to achieve eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0] """ +import collections from typing import Any, NamedTuple - import dp_accounting import tensorflow as tf - from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import tree_aggregation @@ -476,6 +475,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): previous_tree_noise=global_state.previous_tree_noise, ) + def derive_metrics(self, global_state): + """Returns the clip norm as a metric.""" + return collections.OrderedDict(tree_agg_dpftrl_clip=global_state.clip_value) + @classmethod def build_l2_gaussian_query(cls, clip_norm, diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index f8afe09..4f7f84b 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -431,6 +431,18 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase): ) self.assertIsInstance(query._tree_aggregator, tree_class) + def test_derive_metrics(self): + specs = tf.TensorSpec([]) + l2_clip = 2 + query = tree_aggregation_query.TreeResidualSumQuery( + clip_fn=_get_l2_clip_fn(), + clip_value=l2_clip, + noise_generator=_get_noise_fn(specs, 1.0), + record_specs=specs, + ) + metrics = query.derive_metrics(query.initial_global_state()) + self.assertEqual(metrics['tree_agg_dpftrl_clip'], l2_clip) + @parameterized.named_parameters( ('s0t1f1', 0., 1., 1), ('s0t1f2', 0., 1., 2),