From 37ff5d502ebb263243259cbb5f2bf548a08fee2a Mon Sep 17 00:00:00 2001 From: Galen Andrew Date: Tue, 11 Aug 2020 22:58:19 -0700 Subject: [PATCH] Add derive_metrics function to DPQuery. derive_metrics is a new function in the public API so customers can query aspects of the global state that change, such as the clip when using adaptive clipping. PiperOrigin-RevId: 326174158 --- .../privacy/dp_query/dp_query.py | 24 ++++++++-- .../privacy/dp_query/nested_query.py | 45 ++++++++++--------- .../privacy/dp_query/nested_query_test.py | 27 +++++++++++ .../privacy/dp_query/normalized_query.py | 10 ++--- .../quantile_adaptive_clip_sum_query.py | 8 ++-- .../dp_query/quantile_estimator_query.py | 7 ++- 6 files changed, 82 insertions(+), 39 deletions(-) diff --git a/tensorflow_privacy/privacy/dp_query/dp_query.py b/tensorflow_privacy/privacy/dp_query/dp_query.py index 782c893..63f21e1 100644 --- a/tensorflow_privacy/privacy/dp_query/dp_query.py +++ b/tensorflow_privacy/privacy/dp_query/dp_query.py @@ -47,6 +47,7 @@ from __future__ import division from __future__ import print_function import abc +import collections import tensorflow.compat.v1 as tf @@ -83,7 +84,7 @@ class DPQuery(object): return () @abc.abstractmethod - def initial_sample_state(self, template): + def initial_sample_state(self, template=None): """Returns an initial state to use for the next sample. Args: @@ -197,6 +198,20 @@ class DPQuery(object): """ pass + def derive_metrics(self, global_state): + """Derives metric information from the current global state. + + Any metrics returned should be derived only from privatized quantities. + + Args: + global_state: The global state from which to derive metrics. + + Returns: + A `collections.OrderedDict` mapping string metric names to tensor values. + """ + del global_state + return collections.OrderedDict() + def zeros_like(arg): """A `zeros_like` function that also works for `tf.TensorSpec`s.""" @@ -214,11 +229,14 @@ def safe_add(x, y): class SumAggregationDPQuery(DPQuery): """Base class for DPQueries that aggregate via sum.""" - def initial_sample_state(self, template): + def initial_sample_state(self, template=None): return tf.nest.map_structure(zeros_like, template) def accumulate_preprocessed_record(self, sample_state, preprocessed_record): return tf.nest.map_structure(safe_add, sample_state, preprocessed_record) def merge_sample_states(self, sample_state_1, sample_state_2): - return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2) + return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2) + + def get_noised_result(self, sample_state, global_state): + return sample_state, global_state diff --git a/tensorflow_privacy/privacy/dp_query/nested_query.py b/tensorflow_privacy/privacy/dp_query/nested_query.py index 75b2db1..c57d704 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query.py @@ -1,4 +1,4 @@ -# Copyright 2018, The TensorFlow Authors. +# Copyright 2020, The TensorFlow Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.dp_query import dp_query import tree @@ -51,6 +53,7 @@ class NestedQuery(dp_query.DPQuery): self._queries = queries def _map_to_queries(self, fn, *inputs, **kwargs): + """Maps DPQuery methods to the subqueries.""" def caller(query, *args): return getattr(query, fn)(*args, **kwargs) @@ -61,24 +64,22 @@ class NestedQuery(dp_query.DPQuery): self._map_to_queries('set_ledger', ledger=ledger) def initial_global_state(self): - """See base class.""" return self._map_to_queries('initial_global_state') def derive_sample_params(self, global_state): - """See base class.""" return self._map_to_queries('derive_sample_params', global_state) - def initial_sample_state(self, template): - """See base class.""" - return self._map_to_queries('initial_sample_state', template) + def initial_sample_state(self, template=None): + if template is None: + return self._map_to_queries('initial_sample_state') + else: + return self._map_to_queries('initial_sample_state', template) def preprocess_record(self, params, record): - """See base class.""" return self._map_to_queries('preprocess_record', params, record) def accumulate_preprocessed_record( self, sample_state, preprocessed_record): - """See base class.""" return self._map_to_queries( 'accumulate_preprocessed_record', sample_state, @@ -89,18 +90,6 @@ class NestedQuery(dp_query.DPQuery): 'merge_sample_states', sample_state_1, sample_state_2) def get_noised_result(self, sample_state, global_state): - """Gets query result after all records of sample have been accumulated. - - Args: - sample_state: The sample state after all records have been accumulated. - global_state: The global state. - - Returns: - A tuple (result, new_global_state) where "result" is a structure matching - the query structure containing the results of the subqueries and - "new_global_state" is a structure containing the updated global states - for the subqueries. - """ estimates_and_new_global_states = self._map_to_queries( 'get_noised_result', sample_state, global_state) @@ -109,8 +98,22 @@ class NestedQuery(dp_query.DPQuery): return (tf.nest.pack_sequence_as(self._queries, flat_estimates), tf.nest.pack_sequence_as(self._queries, flat_new_global_states)) + def derive_metrics(self, global_state): + metrics = collections.OrderedDict() -class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery): + def add_metrics(tuple_path, subquery, subquery_global_state): + metrics.update({ + '/'.join(str(s) for s in tuple_path + (name,)): metric + for name, metric + in subquery.derive_metrics(subquery_global_state).items()}) + + tree.map_structure_with_path_up_to( + self._queries, add_metrics, self._queries, global_state) + + return metrics + + +class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery): """A NestedQuery that consists only of SumAggregationDPQueries.""" def __init__(self, queries): diff --git a/tensorflow_privacy/privacy/dp_query/nested_query_test.py b/tensorflow_privacy/privacy/dp_query/nested_query_test.py index 625487e..d486c16 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections from absl.testing import parameterized import numpy as np @@ -152,6 +153,32 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase): with self.assertRaises(TypeError): nested_query.NestedSumQuery(non_sum_query) + def test_metrics(self): + class QueryWithMetric(dp_query.SumAggregationDPQuery): + + def __init__(self, metric_val): + self._metric_val = metric_val + + def derive_metrics(self, global_state): + return collections.OrderedDict(metric=self._metric_val) + + query1 = QueryWithMetric(1) + query2 = QueryWithMetric(2) + query3 = QueryWithMetric(3) + + nested_a = nested_query.NestedSumQuery(query1) + global_state = nested_a.initial_global_state() + metric_val = nested_a.derive_metrics(global_state) + self.assertEqual(metric_val['metric'], 1) + + nested_b = nested_query.NestedSumQuery( + {'key1': query1, 'key2': [query2, query3]}) + global_state = nested_b.initial_global_state() + metric_val = nested_b.derive_metrics(global_state) + self.assertEqual(metric_val['key1/metric'], 1) + self.assertEqual(metric_val['key2/0/metric'], 2) + self.assertEqual(metric_val['key2/1/metric'], 3) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_privacy/privacy/dp_query/normalized_query.py b/tensorflow_privacy/privacy/dp_query/normalized_query.py index f5d8f42..b59700f 100644 --- a/tensorflow_privacy/privacy/dp_query/normalized_query.py +++ b/tensorflow_privacy/privacy/dp_query/normalized_query.py @@ -1,4 +1,4 @@ -# Copyright 2019, The TensorFlow Authors. +# Copyright 2020, The TensorFlow Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -48,11 +48,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): assert isinstance(self._numerator, dp_query.SumAggregationDPQuery) def set_ledger(self, ledger): - """See base class.""" self._numerator.set_ledger(ledger) def initial_global_state(self): - """See base class.""" if self._denominator is not None: denominator = tf.cast(self._denominator, tf.float32) else: @@ -61,11 +59,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): self._numerator.initial_global_state(), denominator) def derive_sample_params(self, global_state): - """See base class.""" return self._numerator.derive_sample_params(global_state.numerator_state) def initial_sample_state(self, template): - """See base class.""" # NormalizedQuery has no sample state beyond the numerator state. return self._numerator.initial_sample_state(template) @@ -73,7 +69,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): return self._numerator.preprocess_record(params, record) def get_noised_result(self, sample_state, global_state): - """See base class.""" noised_sum, new_sum_global_state = self._numerator.get_noised_result( sample_state, global_state.numerator_state) def normalize(v): @@ -81,3 +76,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): return (tf.nest.map_structure(normalize, noised_sum), self._GlobalState(new_sum_global_state, global_state.denominator)) + + def derive_metrics(self, global_state): + return self._numerator.derive_metrics(global_state.numerator_state) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py index b260253..bacd442 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py @@ -104,26 +104,22 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): dp_query.SumAggregationDPQuery) def set_ledger(self, ledger): - """See base class.""" self._sum_query.set_ledger(ledger) self._quantile_estimator_query.set_ledger(ledger) def initial_global_state(self): - """See base class.""" return self._GlobalState( tf.cast(self._noise_multiplier, tf.float32), self._sum_query.initial_global_state(), self._quantile_estimator_query.initial_global_state()) def derive_sample_params(self, global_state): - """See base class.""" return self._SampleParams( self._sum_query.derive_sample_params(global_state.sum_state), self._quantile_estimator_query.derive_sample_params( global_state.quantile_estimator_state)) def initial_sample_state(self, template): - """See base class.""" return self._SampleState( self._sum_query.initial_sample_state(template), self._quantile_estimator_query.initial_sample_state()) @@ -138,7 +134,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): return self._SampleState(clipped_record, was_unclipped) def get_noised_result(self, sample_state, global_state): - """See base class.""" noised_vectors, sum_state = self._sum_query.get_noised_result( sample_state.sum_state, global_state.sum_state) del sum_state # To be set explicitly later when we know the new clip. @@ -161,6 +156,9 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): return noised_vectors, new_global_state + def derive_metrics(self, global_state): + return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip) + class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery): """DPQuery for average queries with adaptive clipping. diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py index eae8f8b..68b2865 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py @@ -107,11 +107,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): denominator=expected_num_records) def set_ledger(self, ledger): - """See base class.""" self._below_estimate_query.set_ledger(ledger) def initial_global_state(self): - """See base class.""" return self._GlobalState( tf.cast(self._initial_estimate, tf.float32), tf.cast(self._target_quantile, tf.float32), @@ -119,7 +117,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): self._below_estimate_query.initial_global_state()) def derive_sample_params(self, global_state): - """See base class.""" below_estimate_params = self._below_estimate_query.derive_sample_params( global_state.below_estimate_state) return self._SampleParams(global_state.current_estimate, @@ -141,7 +138,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): params.below_estimate_params, below) def get_noised_result(self, sample_state, global_state): - """See base class.""" below_estimate_result, new_below_estimate_state = ( self._below_estimate_query.get_noised_result( sample_state, @@ -170,6 +166,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): return new_estimate, new_global_state + def derive_metrics(self, global_state): + return collections.OrderedDict(estimate=global_state.current_estimate) + class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery): """Iterative process to estimate target quantile of a univariate distribution.