Add derive_metrics function to DPQuery.
derive_metrics is a new function in the public API so customers can query aspects of the global state that change, such as the clip when using adaptive clipping. PiperOrigin-RevId: 326174158
This commit is contained in:
parent
06bb047525
commit
37ff5d502e
6 changed files with 82 additions and 39 deletions
|
@ -47,6 +47,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import collections
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
@ -83,7 +84,7 @@ class DPQuery(object):
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template=None):
|
||||||
"""Returns an initial state to use for the next sample.
|
"""Returns an initial state to use for the next sample.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -197,6 +198,20 @@ class DPQuery(object):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def derive_metrics(self, global_state):
|
||||||
|
"""Derives metric information from the current global state.
|
||||||
|
|
||||||
|
Any metrics returned should be derived only from privatized quantities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
global_state: The global state from which to derive metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `collections.OrderedDict` mapping string metric names to tensor values.
|
||||||
|
"""
|
||||||
|
del global_state
|
||||||
|
return collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
def zeros_like(arg):
|
def zeros_like(arg):
|
||||||
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
|
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
|
||||||
|
@ -214,11 +229,14 @@ def safe_add(x, y):
|
||||||
class SumAggregationDPQuery(DPQuery):
|
class SumAggregationDPQuery(DPQuery):
|
||||||
"""Base class for DPQueries that aggregate via sum."""
|
"""Base class for DPQueries that aggregate via sum."""
|
||||||
|
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template=None):
|
||||||
return tf.nest.map_structure(zeros_like, template)
|
return tf.nest.map_structure(zeros_like, template)
|
||||||
|
|
||||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||||
return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
|
return tf.nest.map_structure(safe_add, sample_state, preprocessed_record)
|
||||||
|
|
||||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||||
return tf.nest.map_structure(safe_add, sample_state_1, sample_state_2)
|
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||||
|
|
||||||
|
def get_noised_result(self, sample_state, global_state):
|
||||||
|
return sample_state, global_state
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2018, The TensorFlow Authors.
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
import tree
|
import tree
|
||||||
|
@ -51,6 +53,7 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
self._queries = queries
|
self._queries = queries
|
||||||
|
|
||||||
def _map_to_queries(self, fn, *inputs, **kwargs):
|
def _map_to_queries(self, fn, *inputs, **kwargs):
|
||||||
|
"""Maps DPQuery methods to the subqueries."""
|
||||||
def caller(query, *args):
|
def caller(query, *args):
|
||||||
return getattr(query, fn)(*args, **kwargs)
|
return getattr(query, fn)(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -61,24 +64,22 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
self._map_to_queries('set_ledger', ledger=ledger)
|
self._map_to_queries('set_ledger', ledger=ledger)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""See base class."""
|
|
||||||
return self._map_to_queries('initial_global_state')
|
return self._map_to_queries('initial_global_state')
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""See base class."""
|
|
||||||
return self._map_to_queries('derive_sample_params', global_state)
|
return self._map_to_queries('derive_sample_params', global_state)
|
||||||
|
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template=None):
|
||||||
"""See base class."""
|
if template is None:
|
||||||
return self._map_to_queries('initial_sample_state', template)
|
return self._map_to_queries('initial_sample_state')
|
||||||
|
else:
|
||||||
|
return self._map_to_queries('initial_sample_state', template)
|
||||||
|
|
||||||
def preprocess_record(self, params, record):
|
def preprocess_record(self, params, record):
|
||||||
"""See base class."""
|
|
||||||
return self._map_to_queries('preprocess_record', params, record)
|
return self._map_to_queries('preprocess_record', params, record)
|
||||||
|
|
||||||
def accumulate_preprocessed_record(
|
def accumulate_preprocessed_record(
|
||||||
self, sample_state, preprocessed_record):
|
self, sample_state, preprocessed_record):
|
||||||
"""See base class."""
|
|
||||||
return self._map_to_queries(
|
return self._map_to_queries(
|
||||||
'accumulate_preprocessed_record',
|
'accumulate_preprocessed_record',
|
||||||
sample_state,
|
sample_state,
|
||||||
|
@ -89,18 +90,6 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
'merge_sample_states', sample_state_1, sample_state_2)
|
'merge_sample_states', sample_state_1, sample_state_2)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""Gets query result after all records of sample have been accumulated.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_state: The sample state after all records have been accumulated.
|
|
||||||
global_state: The global state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (result, new_global_state) where "result" is a structure matching
|
|
||||||
the query structure containing the results of the subqueries and
|
|
||||||
"new_global_state" is a structure containing the updated global states
|
|
||||||
for the subqueries.
|
|
||||||
"""
|
|
||||||
estimates_and_new_global_states = self._map_to_queries(
|
estimates_and_new_global_states = self._map_to_queries(
|
||||||
'get_noised_result', sample_state, global_state)
|
'get_noised_result', sample_state, global_state)
|
||||||
|
|
||||||
|
@ -109,8 +98,22 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||||
|
|
||||||
|
def derive_metrics(self, global_state):
|
||||||
|
metrics = collections.OrderedDict()
|
||||||
|
|
||||||
class NestedSumQuery(dp_query.SumAggregationDPQuery, NestedQuery):
|
def add_metrics(tuple_path, subquery, subquery_global_state):
|
||||||
|
metrics.update({
|
||||||
|
'/'.join(str(s) for s in tuple_path + (name,)): metric
|
||||||
|
for name, metric
|
||||||
|
in subquery.derive_metrics(subquery_global_state).items()})
|
||||||
|
|
||||||
|
tree.map_structure_with_path_up_to(
|
||||||
|
self._queries, add_metrics, self._queries, global_state)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery):
|
||||||
"""A NestedQuery that consists only of SumAggregationDPQueries."""
|
"""A NestedQuery that consists only of SumAggregationDPQueries."""
|
||||||
|
|
||||||
def __init__(self, queries):
|
def __init__(self, queries):
|
||||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -152,6 +153,32 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
nested_query.NestedSumQuery(non_sum_query)
|
nested_query.NestedSumQuery(non_sum_query)
|
||||||
|
|
||||||
|
def test_metrics(self):
|
||||||
|
class QueryWithMetric(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
|
def __init__(self, metric_val):
|
||||||
|
self._metric_val = metric_val
|
||||||
|
|
||||||
|
def derive_metrics(self, global_state):
|
||||||
|
return collections.OrderedDict(metric=self._metric_val)
|
||||||
|
|
||||||
|
query1 = QueryWithMetric(1)
|
||||||
|
query2 = QueryWithMetric(2)
|
||||||
|
query3 = QueryWithMetric(3)
|
||||||
|
|
||||||
|
nested_a = nested_query.NestedSumQuery(query1)
|
||||||
|
global_state = nested_a.initial_global_state()
|
||||||
|
metric_val = nested_a.derive_metrics(global_state)
|
||||||
|
self.assertEqual(metric_val['metric'], 1)
|
||||||
|
|
||||||
|
nested_b = nested_query.NestedSumQuery(
|
||||||
|
{'key1': query1, 'key2': [query2, query3]})
|
||||||
|
global_state = nested_b.initial_global_state()
|
||||||
|
metric_val = nested_b.derive_metrics(global_state)
|
||||||
|
self.assertEqual(metric_val['key1/metric'], 1)
|
||||||
|
self.assertEqual(metric_val['key2/0/metric'], 2)
|
||||||
|
self.assertEqual(metric_val['key2/1/metric'], 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2019, The TensorFlow Authors.
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -48,11 +48,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||||
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
|
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
"""See base class."""
|
|
||||||
self._numerator.set_ledger(ledger)
|
self._numerator.set_ledger(ledger)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""See base class."""
|
|
||||||
if self._denominator is not None:
|
if self._denominator is not None:
|
||||||
denominator = tf.cast(self._denominator, tf.float32)
|
denominator = tf.cast(self._denominator, tf.float32)
|
||||||
else:
|
else:
|
||||||
|
@ -61,11 +59,9 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||||
self._numerator.initial_global_state(), denominator)
|
self._numerator.initial_global_state(), denominator)
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""See base class."""
|
|
||||||
return self._numerator.derive_sample_params(global_state.numerator_state)
|
return self._numerator.derive_sample_params(global_state.numerator_state)
|
||||||
|
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
|
||||||
# NormalizedQuery has no sample state beyond the numerator state.
|
# NormalizedQuery has no sample state beyond the numerator state.
|
||||||
return self._numerator.initial_sample_state(template)
|
return self._numerator.initial_sample_state(template)
|
||||||
|
|
||||||
|
@ -73,7 +69,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||||
return self._numerator.preprocess_record(params, record)
|
return self._numerator.preprocess_record(params, record)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
|
||||||
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
|
||||||
sample_state, global_state.numerator_state)
|
sample_state, global_state.numerator_state)
|
||||||
def normalize(v):
|
def normalize(v):
|
||||||
|
@ -81,3 +76,6 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
return (tf.nest.map_structure(normalize, noised_sum),
|
return (tf.nest.map_structure(normalize, noised_sum),
|
||||||
self._GlobalState(new_sum_global_state, global_state.denominator))
|
self._GlobalState(new_sum_global_state, global_state.denominator))
|
||||||
|
|
||||||
|
def derive_metrics(self, global_state):
|
||||||
|
return self._numerator.derive_metrics(global_state.numerator_state)
|
||||||
|
|
|
@ -104,26 +104,22 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
dp_query.SumAggregationDPQuery)
|
dp_query.SumAggregationDPQuery)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
"""See base class."""
|
|
||||||
self._sum_query.set_ledger(ledger)
|
self._sum_query.set_ledger(ledger)
|
||||||
self._quantile_estimator_query.set_ledger(ledger)
|
self._quantile_estimator_query.set_ledger(ledger)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""See base class."""
|
|
||||||
return self._GlobalState(
|
return self._GlobalState(
|
||||||
tf.cast(self._noise_multiplier, tf.float32),
|
tf.cast(self._noise_multiplier, tf.float32),
|
||||||
self._sum_query.initial_global_state(),
|
self._sum_query.initial_global_state(),
|
||||||
self._quantile_estimator_query.initial_global_state())
|
self._quantile_estimator_query.initial_global_state())
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""See base class."""
|
|
||||||
return self._SampleParams(
|
return self._SampleParams(
|
||||||
self._sum_query.derive_sample_params(global_state.sum_state),
|
self._sum_query.derive_sample_params(global_state.sum_state),
|
||||||
self._quantile_estimator_query.derive_sample_params(
|
self._quantile_estimator_query.derive_sample_params(
|
||||||
global_state.quantile_estimator_state))
|
global_state.quantile_estimator_state))
|
||||||
|
|
||||||
def initial_sample_state(self, template):
|
def initial_sample_state(self, template):
|
||||||
"""See base class."""
|
|
||||||
return self._SampleState(
|
return self._SampleState(
|
||||||
self._sum_query.initial_sample_state(template),
|
self._sum_query.initial_sample_state(template),
|
||||||
self._quantile_estimator_query.initial_sample_state())
|
self._quantile_estimator_query.initial_sample_state())
|
||||||
|
@ -138,7 +134,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
return self._SampleState(clipped_record, was_unclipped)
|
return self._SampleState(clipped_record, was_unclipped)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
|
||||||
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
noised_vectors, sum_state = self._sum_query.get_noised_result(
|
||||||
sample_state.sum_state, global_state.sum_state)
|
sample_state.sum_state, global_state.sum_state)
|
||||||
del sum_state # To be set explicitly later when we know the new clip.
|
del sum_state # To be set explicitly later when we know the new clip.
|
||||||
|
@ -161,6 +156,9 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
return noised_vectors, new_global_state
|
return noised_vectors, new_global_state
|
||||||
|
|
||||||
|
def derive_metrics(self, global_state):
|
||||||
|
return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip)
|
||||||
|
|
||||||
|
|
||||||
class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
|
class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
|
||||||
"""DPQuery for average queries with adaptive clipping.
|
"""DPQuery for average queries with adaptive clipping.
|
||||||
|
|
|
@ -107,11 +107,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
denominator=expected_num_records)
|
denominator=expected_num_records)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
"""See base class."""
|
|
||||||
self._below_estimate_query.set_ledger(ledger)
|
self._below_estimate_query.set_ledger(ledger)
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
"""See base class."""
|
|
||||||
return self._GlobalState(
|
return self._GlobalState(
|
||||||
tf.cast(self._initial_estimate, tf.float32),
|
tf.cast(self._initial_estimate, tf.float32),
|
||||||
tf.cast(self._target_quantile, tf.float32),
|
tf.cast(self._target_quantile, tf.float32),
|
||||||
|
@ -119,7 +117,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
self._below_estimate_query.initial_global_state())
|
self._below_estimate_query.initial_global_state())
|
||||||
|
|
||||||
def derive_sample_params(self, global_state):
|
def derive_sample_params(self, global_state):
|
||||||
"""See base class."""
|
|
||||||
below_estimate_params = self._below_estimate_query.derive_sample_params(
|
below_estimate_params = self._below_estimate_query.derive_sample_params(
|
||||||
global_state.below_estimate_state)
|
global_state.below_estimate_state)
|
||||||
return self._SampleParams(global_state.current_estimate,
|
return self._SampleParams(global_state.current_estimate,
|
||||||
|
@ -141,7 +138,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
params.below_estimate_params, below)
|
params.below_estimate_params, below)
|
||||||
|
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
"""See base class."""
|
|
||||||
below_estimate_result, new_below_estimate_state = (
|
below_estimate_result, new_below_estimate_state = (
|
||||||
self._below_estimate_query.get_noised_result(
|
self._below_estimate_query.get_noised_result(
|
||||||
sample_state,
|
sample_state,
|
||||||
|
@ -170,6 +166,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
return new_estimate, new_global_state
|
return new_estimate, new_global_state
|
||||||
|
|
||||||
|
def derive_metrics(self, global_state):
|
||||||
|
return collections.OrderedDict(estimate=global_state.current_estimate)
|
||||||
|
|
||||||
|
|
||||||
class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
"""Iterative process to estimate target quantile of a univariate distribution.
|
"""Iterative process to estimate target quantile of a univariate distribution.
|
||||||
|
|
Loading…
Reference in a new issue