Add derive_metrics function to DPQuery.

derive_metrics is a new function in the public API so customers can query aspects of the global state that change, such as the clip when using adaptive clipping.

PiperOrigin-RevId: 326174158
This commit is contained in:
Galen Andrew 2020-08-11 22:58:19 -07:00 committed by A. Unique TensorFlower
parent 06bb047525
commit 37ff5d502e
6 changed files with 82 additions and 39 deletions

View file

@ -47,6 +47,7 @@ from __future__ import division
from __future__ import print_function
import abc
import collections
import tensorflow.compat.v1 as tf
@ -83,7 +84,7 @@ class DPQuery(object):
return ()
@abc.abstractmethod
def initial_sample_state(self, template):
def initial_sample_state(self, template=None):
"""Returns an initial state to use for the next sample.
Args:
@ -197,6 +198,20 @@ class DPQuery(object):
"""
pass
def derive_metrics(self, global_state):
"""Derives metric information from the current global state.
Any metrics returned should be derived only from privatized quantities.
Args:
global_state: The global state from which to derive metrics.
Returns:
A `collections.OrderedDict` mapping string metric names to tensor values.
"""
del global_state
return collections.OrderedDict()
def zeros_like(arg):
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
@ -214,11 +229,14 @@ def safe_add(x, y):
class SumAggregationDPQuery(DPQuery):
"""Base class for DPQueries that aggregate via sum."""
def initial_sample_state(self, template):
def initial_sample_state(self, template=None):
return tf.nest.map_structure(zeros_like, template)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2)
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
return sample_state, global_state

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors.
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query
import tree
@ -51,6 +53,7 @@ class NestedQuery(dp_query.DPQuery):
self._queries = queries
def _map_to_queries(self, fn, *inputs, **kwargs):
"""Maps DPQuery methods to the subqueries."""
def caller(query, *args):
return getattr(query, fn)(*args, **kwargs)
@ -61,24 +64,22 @@ class NestedQuery(dp_query.DPQuery):
self._map_to_queries('set_ledger', ledger=ledger)
def initial_global_state(self):
"""See base class."""
return self._map_to_queries('initial_global_state')
def derive_sample_params(self, global_state):
"""See base class."""
return self._map_to_queries('derive_sample_params', global_state)
def initial_sample_state(self, template):
"""See base class."""
def initial_sample_state(self, template=None):
if template is None:
return self._map_to_queries('initial_sample_state')
else:
return self._map_to_queries('initial_sample_state', template)
def preprocess_record(self, params, record):
"""See base class."""
return self._map_to_queries('preprocess_record', params, record)
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
"""See base class."""
return self._map_to_queries(
'accumulate_preprocessed_record',
sample_state,
@ -89,18 +90,6 @@ class NestedQuery(dp_query.DPQuery):
'merge_sample_states', sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Gets query result after all records of sample have been accumulated.
Args:
sample_state: The sample state after all records have been accumulated.
global_state: The global state.
Returns:
A tuple (result, new_global_state) where "result" is a structure matching
the query structure containing the results of the subqueries and
"new_global_state" is a structure containing the updated global states
for the subqueries.
"""
estimates_and_new_global_states = self._map_to_queries(
'get_noised_result', sample_state, global_state)
@ -109,8 +98,22 @@ class NestedQuery(dp_query.DPQuery):
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
def derive_metrics(self, global_state):
metrics = collections.OrderedDict()
class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery):
def add_metrics(tuple_path, subquery, subquery_global_state):
metrics.update({
'/'.join(str(s) for s in tuple_path + (name,)): metric
for name, metric
in subquery.derive_metrics(subquery_global_state).items()})
tree.map_structure_with_path_up_to(
self._queries, add_metrics, self._queries, global_state)
return metrics
class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery):
"""A NestedQuery that consists only of SumAggregationDPQueries."""
def __init__(self, queries):

View file

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from absl.testing import parameterized
import numpy as np
@ -152,6 +153,32 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(TypeError):
nested_query.NestedSumQuery(non_sum_query)
def test_metrics(self):
class QueryWithMetric(dp_query.SumAggregationDPQuery):
def __init__(self, metric_val):
self._metric_val = metric_val
def derive_metrics(self, global_state):
return collections.OrderedDict(metric=self._metric_val)
query1 = QueryWithMetric(1)
query2 = QueryWithMetric(2)
query3 = QueryWithMetric(3)
nested_a = nested_query.NestedSumQuery(query1)
global_state = nested_a.initial_global_state()
metric_val = nested_a.derive_metrics(global_state)
self.assertEqual(metric_val['metric'], 1)
nested_b = nested_query.NestedSumQuery(
{'key1': query1, 'key2': [query2, query3]})
global_state = nested_b.initial_global_state()
metric_val = nested_b.derive_metrics(global_state)
self.assertEqual(metric_val['key1/metric'], 1)
self.assertEqual(metric_val['key2/0/metric'], 2)
self.assertEqual(metric_val['key2/1/metric'], 3)
if __name__ == '__main__':
tf.test.main()

View file

@ -1,4 +1,4 @@
# Copyright 2019, The TensorFlow Authors.
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -48,11 +48,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
def set_ledger(self, ledger):
"""See base class."""
self._numerator.set_ledger(ledger)
def initial_global_state(self):
"""See base class."""
if self._denominator is not None:
denominator = tf.cast(self._denominator, tf.float32)
else:
@ -61,11 +59,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
self._numerator.initial_global_state(), denominator)
def derive_sample_params(self, global_state):
"""See base class."""
return self._numerator.derive_sample_params(global_state.numerator_state)
def initial_sample_state(self, template):
"""See base class."""
# NormalizedQuery has no sample state beyond the numerator state.
return self._numerator.initial_sample_state(template)
@ -73,7 +69,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
return self._numerator.preprocess_record(params, record)
def get_noised_result(self, sample_state, global_state):
"""See base class."""
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
sample_state, global_state.numerator_state)
def normalize(v):
@ -81,3 +76,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
return (tf.nest.map_structure(normalize, noised_sum),
self._GlobalState(new_sum_global_state, global_state.denominator))
def derive_metrics(self, global_state):
return self._numerator.derive_metrics(global_state.numerator_state)

View file

@ -104,26 +104,22 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
dp_query.SumAggregationDPQuery)
def set_ledger(self, ledger):
"""See base class."""
self._sum_query.set_ledger(ledger)
self._quantile_estimator_query.set_ledger(ledger)
def initial_global_state(self):
"""See base class."""
return self._GlobalState(
tf.cast(self._noise_multiplier, tf.float32),
self._sum_query.initial_global_state(),
self._quantile_estimator_query.initial_global_state())
def derive_sample_params(self, global_state):
"""See base class."""
return self._SampleParams(
self._sum_query.derive_sample_params(global_state.sum_state),
self._quantile_estimator_query.derive_sample_params(
global_state.quantile_estimator_state))
def initial_sample_state(self, template):
"""See base class."""
return self._SampleState(
self._sum_query.initial_sample_state(template),
self._quantile_estimator_query.initial_sample_state())
@ -138,7 +134,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
return self._SampleState(clipped_record, was_unclipped)
def get_noised_result(self, sample_state, global_state):
"""See base class."""
noised_vectors, sum_state = self._sum_query.get_noised_result(
sample_state.sum_state, global_state.sum_state)
del sum_state # To be set explicitly later when we know the new clip.
@ -161,6 +156,9 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
return noised_vectors, new_global_state
def derive_metrics(self, global_state):
return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip)
class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
"""DPQuery for average queries with adaptive clipping.

View file

@ -107,11 +107,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
denominator=expected_num_records)
def set_ledger(self, ledger):
"""See base class."""
self._below_estimate_query.set_ledger(ledger)
def initial_global_state(self):
"""See base class."""
return self._GlobalState(
tf.cast(self._initial_estimate, tf.float32),
tf.cast(self._target_quantile, tf.float32),
@ -119,7 +117,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
self._below_estimate_query.initial_global_state())
def derive_sample_params(self, global_state):
"""See base class."""
below_estimate_params = self._below_estimate_query.derive_sample_params(
global_state.below_estimate_state)
return self._SampleParams(global_state.current_estimate,
@ -141,7 +138,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
params.below_estimate_params, below)
def get_noised_result(self, sample_state, global_state):
"""See base class."""
below_estimate_result, new_below_estimate_state = (
self._below_estimate_query.get_noised_result(
sample_state,
@ -170,6 +166,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
return new_estimate, new_global_state
def derive_metrics(self, global_state):
return collections.OrderedDict(estimate=global_state.current_estimate)
class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
"""Iterative process to estimate target quantile of a univariate distribution.