Add derive_metrics function to DPQuery.
derive_metrics is a new function in the public API so customers can query aspects of the global state that change, such as the clip when using adaptive clipping. PiperOrigin-RevId: 326174158
This commit is contained in:
parent
06bb047525
commit
37ff5d502e
6 changed files with 82 additions and 39 deletions
|
@ -47,6 +47,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
@ -83,7 +84,7 @@ class DPQuery(object):
|
|||
return ()
|
||||
|
||||
@abc.abstractmethod
|
||||
def initial_sample_state(self, template):
|
||||
def initial_sample_state(self, template=None):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
||||
Args:
|
||||
|
@ -197,6 +198,20 @@ class DPQuery(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
"""Derives metric information from the current global state.
|
||||
|
||||
Any metrics returned should be derived only from privatized quantities.
|
||||
|
||||
Args:
|
||||
global_state: The global state from which to derive metrics.
|
||||
|
||||
Returns:
|
||||
A `collections.OrderedDict` mapping string metric names to tensor values.
|
||||
"""
|
||||
del global_state
|
||||
return collections.OrderedDict()
|
||||
|
||||
|
||||
def zeros_like(arg):
|
||||
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
|
||||
|
@ -214,11 +229,14 @@ def safe_add(x, y):
|
|||
class SumAggregationDPQuery(DPQuery):
|
||||
"""Base class for DPQueries that aggregate via sum."""
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
def initial_sample_state(self, template=None):
|
||||
return tf.nest.map_structure(zeros_like, template)
|
||||
|
||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||
return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2)
|
||||
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
return sample_state, global_state
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
import tree
|
||||
|
@ -51,6 +53,7 @@ class NestedQuery(dp_query.DPQuery):
|
|||
self._queries = queries
|
||||
|
||||
def _map_to_queries(self, fn, *inputs, **kwargs):
|
||||
"""Maps DPQuery methods to the subqueries."""
|
||||
def caller(query, *args):
|
||||
return getattr(query, fn)(*args, **kwargs)
|
||||
|
||||
|
@ -61,24 +64,22 @@ class NestedQuery(dp_query.DPQuery):
|
|||
self._map_to_queries('set_ledger', ledger=ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""See base class."""
|
||||
return self._map_to_queries('initial_global_state')
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""See base class."""
|
||||
return self._map_to_queries('derive_sample_params', global_state)
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""See base class."""
|
||||
def initial_sample_state(self, template=None):
|
||||
if template is None:
|
||||
return self._map_to_queries('initial_sample_state')
|
||||
else:
|
||||
return self._map_to_queries('initial_sample_state', template)
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
"""See base class."""
|
||||
return self._map_to_queries('preprocess_record', params, record)
|
||||
|
||||
def accumulate_preprocessed_record(
|
||||
self, sample_state, preprocessed_record):
|
||||
"""See base class."""
|
||||
return self._map_to_queries(
|
||||
'accumulate_preprocessed_record',
|
||||
sample_state,
|
||||
|
@ -89,18 +90,6 @@ class NestedQuery(dp_query.DPQuery):
|
|||
'merge_sample_states', sample_state_1, sample_state_2)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Gets query result after all records of sample have been accumulated.
|
||||
|
||||
Args:
|
||||
sample_state: The sample state after all records have been accumulated.
|
||||
global_state: The global state.
|
||||
|
||||
Returns:
|
||||
A tuple (result, new_global_state) where "result" is a structure matching
|
||||
the query structure containing the results of the subqueries and
|
||||
"new_global_state" is a structure containing the updated global states
|
||||
for the subqueries.
|
||||
"""
|
||||
estimates_and_new_global_states = self._map_to_queries(
|
||||
'get_noised_result', sample_state, global_state)
|
||||
|
||||
|
@ -109,8 +98,22 @@ class NestedQuery(dp_query.DPQuery):
|
|||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
metrics = collections.OrderedDict()
|
||||
|
||||
class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery):
|
||||
def add_metrics(tuple_path, subquery, subquery_global_state):
|
||||
metrics.update({
|
||||
'/'.join(str(s) for s in tuple_path + (name,)): metric
|
||||
for name, metric
|
||||
in subquery.derive_metrics(subquery_global_state).items()})
|
||||
|
||||
tree.map_structure_with_path_up_to(
|
||||
self._queries, add_metrics, self._queries, global_state)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery):
|
||||
"""A NestedQuery that consists only of SumAggregationDPQueries."""
|
||||
|
||||
def __init__(self, queries):
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
@ -152,6 +153,32 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
with self.assertRaises(TypeError):
|
||||
nested_query.NestedSumQuery(non_sum_query)
|
||||
|
||||
def test_metrics(self):
|
||||
class QueryWithMetric(dp_query.SumAggregationDPQuery):
|
||||
|
||||
def __init__(self, metric_val):
|
||||
self._metric_val = metric_val
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
return collections.OrderedDict(metric=self._metric_val)
|
||||
|
||||
query1 = QueryWithMetric(1)
|
||||
query2 = QueryWithMetric(2)
|
||||
query3 = QueryWithMetric(3)
|
||||
|
||||
nested_a = nested_query.NestedSumQuery(query1)
|
||||
global_state = nested_a.initial_global_state()
|
||||
metric_val = nested_a.derive_metrics(global_state)
|
||||
self.assertEqual(metric_val['metric'], 1)
|
||||
|
||||
nested_b = nested_query.NestedSumQuery(
|
||||
{'key1': query1, 'key2': [query2, query3]})
|
||||
global_state = nested_b.initial_global_state()
|
||||
metric_val = nested_b.derive_metrics(global_state)
|
||||
self.assertEqual(metric_val['key1/metric'], 1)
|
||||
self.assertEqual(metric_val['key2/0/metric'], 2)
|
||||
self.assertEqual(metric_val['key2/1/metric'], 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019, The TensorFlow Authors.
|
||||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -48,11 +48,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""See base class."""
|
||||
self._numerator.set_ledger(ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""See base class."""
|
||||
if self._denominator is not None:
|
||||
denominator = tf.cast(self._denominator, tf.float32)
|
||||
else:
|
||||
|
@ -61,11 +59,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
self._numerator.initial_global_state(), denominator)
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""See base class."""
|
||||
return self._numerator.derive_sample_params(global_state.numerator_state)
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""See base class."""
|
||||
# NormalizedQuery has no sample state beyond the numerator state.
|
||||
return self._numerator.initial_sample_state(template)
|
||||
|
||||
|
@ -73,7 +69,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
return self._numerator.preprocess_record(params, record)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||
sample_state, global_state.numerator_state)
|
||||
def normalize(v):
|
||||
|
@ -81,3 +76,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
return (tf.nest.map_structure(normalize, noised_sum),
|
||||
self._GlobalState(new_sum_global_state, global_state.denominator))
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
return self._numerator.derive_metrics(global_state.numerator_state)
|
||||
|
|
|
@ -104,26 +104,22 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
dp_query.SumAggregationDPQuery)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""See base class."""
|
||||
self._sum_query.set_ledger(ledger)
|
||||
self._quantile_estimator_query.set_ledger(ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""See base class."""
|
||||
return self._GlobalState(
|
||||
tf.cast(self._noise_multiplier, tf.float32),
|
||||
self._sum_query.initial_global_state(),
|
||||
self._quantile_estimator_query.initial_global_state())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""See base class."""
|
||||
return self._SampleParams(
|
||||
self._sum_query.derive_sample_params(global_state.sum_state),
|
||||
self._quantile_estimator_query.derive_sample_params(
|
||||
global_state.quantile_estimator_state))
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
"""See base class."""
|
||||
return self._SampleState(
|
||||
self._sum_query.initial_sample_state(template),
|
||||
self._quantile_estimator_query.initial_sample_state())
|
||||
|
@ -138,7 +134,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return self._SampleState(clipped_record, was_unclipped)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
||||
sample_state.sum_state, global_state.sum_state)
|
||||
del sum_state # To be set explicitly later when we know the new clip.
|
||||
|
@ -161,6 +156,9 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
return noised_vectors, new_global_state
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip)
|
||||
|
||||
|
||||
class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
|
||||
"""DPQuery for average queries with adaptive clipping.
|
||||
|
|
|
@ -107,11 +107,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
denominator=expected_num_records)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
"""See base class."""
|
||||
self._below_estimate_query.set_ledger(ledger)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""See base class."""
|
||||
return self._GlobalState(
|
||||
tf.cast(self._initial_estimate, tf.float32),
|
||||
tf.cast(self._target_quantile, tf.float32),
|
||||
|
@ -119,7 +117,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
self._below_estimate_query.initial_global_state())
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""See base class."""
|
||||
below_estimate_params = self._below_estimate_query.derive_sample_params(
|
||||
global_state.below_estimate_state)
|
||||
return self._SampleParams(global_state.current_estimate,
|
||||
|
@ -141,7 +138,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
params.below_estimate_params, below)
|
||||
|
||||
def get_noised_result(self, sample_state, global_state):
|
||||
"""See base class."""
|
||||
below_estimate_result, new_below_estimate_state = (
|
||||
self._below_estimate_query.get_noised_result(
|
||||
sample_state,
|
||||
|
@ -170,6 +166,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
return new_estimate, new_global_state
|
||||
|
||||
def derive_metrics(self, global_state):
|
||||
return collections.OrderedDict(estimate=global_state.current_estimate)
|
||||
|
||||
|
||||
class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||
|
|
Loading…
Reference in a new issue