diff --git a/tensorflow_privacy/privacy/analysis/privacy_ledger.py b/tensorflow_privacy/privacy/analysis/privacy_ledger.py index ff7eded..08dee5d 100644 --- a/tensorflow_privacy/privacy/analysis/privacy_ledger.py +++ b/tensorflow_privacy/privacy/analysis/privacy_ledger.py @@ -54,7 +54,31 @@ class PrivacyLedger(object): """Class for keeping a record of private queries. The PrivacyLedger keeps a record of all queries executed over a given dataset - for the purpose of computing privacy guarantees. + for the purpose of computing privacy guarantees. To use it, it must be + associated with a `DPQuery` object via a `QueryWithLedger`. + + The current implementation works only with DPQueries that consist of composing + Gaussian sum mechanism with Poisson subsampling. + + Example usage: + + ``` + import tensorflow_privacy as tfp + + dp_query = tfp.QueryWithLedger( + tensorflow_privacy.GaussianSumQuery( + l2_norm_clip=1.0, stddev=1.0), + population_size=10000, + selection_probability=0.01) + + # Use dp_query here in training loop. + + formatted_ledger = dp_query.ledger.get_formatted_ledger_eager() + orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] + + list(range(5, 64)) + [128, 256, 512]) + total_rdp = tfp.compute_rdp_from_ledger(formatted_ledger, orders) + epsilon = tfp.get_privacy_spent(orders, total_rdp, target_delta=1e-5) + ``` """ def __init__(self, @@ -106,7 +130,8 @@ class PrivacyLedger(object): noise_stddev: The standard deviation of the noise applied to the sum. Returns: - An operation recording the sum query to the ledger. + An operation recording the sum query to the ledger. This should be called + for every Gaussian sum query that is issued on a sample. """ def _do_record_query(): @@ -118,7 +143,15 @@ class PrivacyLedger(object): return self._cs.execute(_do_record_query) def finalize_sample(self): - """Finalizes sample and records sample ledger entry.""" + """Finalizes sample and records sample ledger entry. + + This should be called once per application of the mechanism on a sample, + after all sum queries have been recorded. + + Returns: + An operation recording the complete mechanism (sampling and sum + estimation) to the ledger. + """ with tf.control_dependencies([ tf.assign(self._sample_var, [ self._population_size, self._selection_probability, @@ -132,6 +165,7 @@ class PrivacyLedger(object): return self._sample_buffer.append(self._sample_var) def get_unformatted_ledger(self): + """Returns the raw sample and query values.""" return self._sample_buffer.values, self._query_buffer.values def get_formatted_ledger(self, sess): @@ -169,7 +203,10 @@ class QueryWithLedger(dp_query.DPQuery): those contained in the leaves of a nested query) should also contain a reference to the same ledger object. - For example usage, see `privacy_ledger_test.py`. + Only composed Gaussian sum queries with Poisson subsampling are supported. + This includes `GaussianSumQuery`, `QuantileEstimatorQuery`, and + `QuantileAdaptiveClipSumQuery`, as well as `NestedQuery` or `NormalizedQuery` + objects that contain the previous mentioned query types. """ def __init__(self, query, @@ -185,8 +222,8 @@ class QueryWithLedger(dp_query.DPQuery): population, i.e. size of the training data used in each epoch. May be `None` if `ledger` is specified. selection_probability: A floating point value (may be variable) specifying - the probability each record is included in a sample. May be `None` if - `ledger` is specified. + the probability each record is included in a sample under Poisson + subsampling. May be `None` if `ledger` is specified. ledger: A `PrivacyLedger` to use. Must be specified if either of `population_size` or `selection_probability` is `None`. """ @@ -201,46 +238,62 @@ class QueryWithLedger(dp_query.DPQuery): @property def ledger(self): + """Gets the ledger that all inner queries record to.""" return self._ledger def set_ledger(self, ledger): + """Sets a new ledger.""" self._ledger = ledger self._query.set_ledger(ledger) def initial_global_state(self): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" return self._query.initial_global_state() def derive_sample_params(self, global_state): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return self._query.derive_sample_params(global_state) def initial_sample_state(self, template): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" return self._query.initial_sample_state(template) def preprocess_record(self, params, record): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" return self._query.preprocess_record(params, record) def accumulate_preprocessed_record(self, sample_state, preprocessed_record): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`.""" return self._query.accumulate_preprocessed_record( sample_state, preprocessed_record) def merge_sample_states(self, sample_state_1, sample_state_2): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.merge_sample_states`.""" return self._query.merge_sample_states(sample_state_1, sample_state_2) def get_noised_result(self, sample_state, global_state): - """Ensures sample is recorded to the ledger and returns noised result.""" + """Implements `tensorflow_privacy.DPQuery.derive_metrics`. + + Besides noising and returning the result of the inner query, ensures that + the sample is recorded to the ledger. + + Args: + sample_state: The sample state after all records have been accumulated. + global_state: The global state, storing long-term privacy bookkeeping. + + Returns: + A tuple (result, new_global_state) where "result" is the result of the + query and "new_global_state" is the updated global state. + """ # Ensure sample_state is fully aggregated before calling get_noised_result. with tf.control_dependencies(tf.nest.flatten(sample_state)): result, new_global_state = self._query.get_noised_result( sample_state, global_state) + # Ensure inner queries have recorded before finalizing. with tf.control_dependencies(tf.nest.flatten(result)): finalize = self._ledger.finalize_sample() + # Ensure finalizing happens. with tf.control_dependencies([finalize]): return tf.nest.map_structure(tf.identity, result), new_global_state diff --git a/tensorflow_privacy/privacy/dp_query/dp_query.py b/tensorflow_privacy/privacy/dp_query/dp_query.py index 06d857c..d7f8e18 100644 --- a/tensorflow_privacy/privacy/dp_query/dp_query.py +++ b/tensorflow_privacy/privacy/dp_query/dp_query.py @@ -53,13 +53,58 @@ import tensorflow.compat.v1 as tf class DPQuery(object): - """Interface for differentially private query mechanisms.""" + """Interface for differentially private query mechanisms. + + Differential privacy is achieved by processing records to bound sensitivity, + accumulating the processed records (usually by summing them) and then + adding noise to the aggregated result. The process can be repeated to compose + applications of the same mechanism, possibly with different parameters. + + The DPQuery interface specifies a functional approach to this process. A + global state maintains state that persists across applications of the + mechanism. For each application, the following steps are performed: + + 1. Use the global state to derive parameters to use for the next sample of + records. + 2. Initialize a sample state that will accumulate processed records. + 3. For each record: + a. Process the record. + b. Accumulate the record into the sample state. + 4. Get the result of the mechanism, possibly updating the global state to use + in the next application. + 5. Derive metrics from the global state. + + Here is an example using the GaussianSumQuery. Assume there is some function + records_for_round(round) that returns an iterable of records to use on some + round. + + ``` + dp_query = tensorflow_privacy.GaussianSumQuery( + l2_norm_clip=1.0, stddev=1.0) + global_state = dp_query.initial_global_state() + + for round in range(num_rounds): + sample_params = dp_query.derive_sample_params(global_state) + sample_state = dp_query.initial_sample_state() + for record in records_for_round(round): + sample_state = dp_query.accumulate_record( + sample_params, sample_state, record) + + result, global_state = dp_query.get_noised_result( + sample_state, global_state) + metrics = dp_query.derive_metrics(global_state) + + # Do something with result and metrics... + ``` + """ __metaclass__ = abc.ABCMeta def set_ledger(self, ledger): """Supplies privacy ledger to which the query can record privacy events. + The ledger should be updated with each call to get_noised_result. + Args: ledger: A `PrivacyLedger`. """ @@ -68,12 +113,26 @@ class DPQuery(object): 'DPQuery type %s does not support set_ledger.' % type(self).__name__) def initial_global_state(self): - """Returns the initial global state for the DPQuery.""" + """Returns the initial global state for the DPQuery. + + The global state contains any state information that changes across + repeated applications of the mechanism. The default implementation returns + just an empty tuple for implementing classes that do not have any persistent + state. + + Returns: + The global state. + """ return () def derive_sample_params(self, global_state): """Given the global state, derives parameters to use for the next sample. + For example, if the mechanism needs to clip records to bound the norm, + the clipping norm should be part of the sample params. In a distributed + context, this is the part of the state that would be sent to the workers + so they can process records. + Args: global_state: The current global state. @@ -87,6 +146,10 @@ class DPQuery(object): def initial_sample_state(self, template=None): """Returns an initial state to use for the next sample. + For typical `DPQuery` classes that are aggregated by summation, this should + return a nested structure of zero tensors of the appropriate shapes, to + which processed records will be aggregated. + Args: template: A nested structure of tensors, TensorSpecs, or numpy arrays used as a template to create the initial sample state. It is assumed that the @@ -145,7 +208,7 @@ class DPQuery(object): This is a helper method that simply delegates to `preprocess_record` and `accumulate_preprocessed_record` for the common case when both of those - functions run on a single device. + functions run on a single device. Typically this will be a simple sum. Args: params: The parameters for the sample. In standard DP-SGD training, @@ -169,6 +232,11 @@ class DPQuery(object): def merge_sample_states(self, sample_state_1, sample_state_2): """Merges two sample states into a single state. + This can be useful if aggregation is performed hierarchically, where + multiple sample states are used to accumulate records and then + hierarchically merged into the final accumulated state. Typically this will + be a simple sum. + Args: sample_state_1: The first sample state to merge. sample_state_2: The second sample state to merge. @@ -180,11 +248,14 @@ class DPQuery(object): @abc.abstractmethod def get_noised_result(self, sample_state, global_state): - """Gets query result after all records of sample have been accumulated. + """Gets the query result after all records of sample have been accumulated. + + The global state can also be updated for use in the next application of the + DP mechanism. Args: - sample_state: The sample state after all records have been accumulated. - In standard DP-SGD training, the accumulated sum of clipped microbatch + sample_state: The sample state after all records have been accumulated. In + standard DP-SGD training, the accumulated sum of clipped microbatch gradients (in the special case of microbatches of size 1, the clipped per-example gradients). global_state: The global state, storing long-term privacy bookkeeping. @@ -213,7 +284,7 @@ class DPQuery(object): return collections.OrderedDict() -def zeros_like(arg): +def _zeros_like(arg): """A `zeros_like` function that also works for `tf.TensorSpec`s.""" try: arg = tf.convert_to_tensor(value=arg) @@ -222,7 +293,8 @@ def zeros_like(arg): return tf.zeros(arg.shape, arg.dtype) -def safe_add(x, y): +def _safe_add(x, y): + """Adds x and y but if y is None, simply returns x.""" return x if y is None else tf.add(x, y) @@ -230,13 +302,17 @@ class SumAggregationDPQuery(DPQuery): """Base class for DPQueries that aggregate via sum.""" def initial_sample_state(self, template=None): - return tf.nest.map_structure(zeros_like, template) + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" + return tf.nest.map_structure(_zeros_like, template) def accumulate_preprocessed_record(self, sample_state, preprocessed_record): - return tf.nest.map_structure(safe_add, sample_state, preprocessed_record) + """Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`.""" + return tf.nest.map_structure(_safe_add, sample_state, preprocessed_record) def merge_sample_states(self, sample_state_1, sample_state_2): + """Implements `tensorflow_privacy.DPQuery.merge_sample_states`.""" return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2) def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" return sample_state, global_state diff --git a/tensorflow_privacy/privacy/dp_query/gaussian_query.py b/tensorflow_privacy/privacy/dp_query/gaussian_query.py index 28a8bb5..bc0888c 100644 --- a/tensorflow_privacy/privacy/dp_query/gaussian_query.py +++ b/tensorflow_privacy/privacy/dp_query/gaussian_query.py @@ -28,7 +28,7 @@ from tensorflow_privacy.privacy.dp_query import dp_query class GaussianSumQuery(dp_query.SumAggregationDPQuery): """Implements DPQuery interface for Gaussian sum queries. - Accumulates clipped vectors, then adds Gaussian noise to the sum. + Clips records to bound the L2 norm, then adds Gaussian noise to the sum. """ # pylint: disable=invalid-name @@ -48,6 +48,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): self._ledger = None def set_ledger(self, ledger): + """Implements `tensorflow_privacy.DPQuery.set_ledger`.""" self._ledger = ledger def make_global_state(self, l2_norm_clip, stddev): @@ -56,9 +57,11 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): tf.cast(l2_norm_clip, tf.float32), tf.cast(stddev, tf.float32)) def initial_global_state(self): + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" return self.make_global_state(self._l2_norm_clip, self._stddev) def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return global_state.l2_norm_clip def preprocess_record_impl(self, params, record): @@ -79,11 +82,12 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery): return tf.nest.pack_sequence_as(record, clipped_as_list), norm def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" preprocessed_record, _ = self.preprocess_record_impl(params, record) return preprocessed_record def get_noised_result(self, sample_state, global_state): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" if distutils.version.LooseVersion( tf.__version__) < distutils.version.LooseVersion('2.0.0'): diff --git a/tensorflow_privacy/privacy/dp_query/nested_query.py b/tensorflow_privacy/privacy/dp_query/nested_query.py index c57d704..783485e 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query.py @@ -32,12 +32,18 @@ class NestedQuery(dp_query.DPQuery): NestedQuery evaluates arbitrary nested structures of queries. Records must be nested structures of tensors that are compatible (in type and arity) with the query structure, but are allowed to have deeper structure within each leaf of - the query structure. For example, the nested query [q1, q2] is compatible with - the record [t1, t2] or [t1, (t2, t3)], but not with (t1, t2), [t1] or - [t1, t2, t3]. The entire substructure of each record corresponding to a leaf - node of the query structure is routed to the corresponding query. If the same - tensor should be consumed by multiple sub-queries, it can be replicated in the - record, for example [t1, t1]. + the query structure. The entire substructure of each record corresponding to a + leaf node of the query structure is routed to the corresponding query. + + For example, a nested query with structure "[q1, q2]" is compatible with a + record of structure "[t1, (t2, t3)]": t1 would be processed by q1, and (t2, + t3) would be processed by q2. On the other hand, "[q1, q2]" is not compatible + with "(t1, t2)" (type mismatch), "[t1]" (arity-mismatch) or "[t1, t2, t3]" + (arity-mismatch). + + It is possible for the same tensor to be consumed by multiple sub-queries, by + simply replicating it in the record, for example providing "[t1, t1]" to + "[q1, q2]". NestedQuery is intended to allow privacy mechanisms for groups as described in [McMahan & Andrew, 2018: "A General Approach to Adding Differential Privacy to @@ -61,35 +67,43 @@ class NestedQuery(dp_query.DPQuery): *inputs) def set_ledger(self, ledger): + """Implements `tensorflow_privacy.DPQuery.set_ledger`.""" self._map_to_queries('set_ledger', ledger=ledger) def initial_global_state(self): + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" return self._map_to_queries('initial_global_state') def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return self._map_to_queries('derive_sample_params', global_state) def initial_sample_state(self, template=None): + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" if template is None: return self._map_to_queries('initial_sample_state') else: return self._map_to_queries('initial_sample_state', template) def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" return self._map_to_queries('preprocess_record', params, record) def accumulate_preprocessed_record( self, sample_state, preprocessed_record): + """Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`.""" return self._map_to_queries( 'accumulate_preprocessed_record', sample_state, preprocessed_record) def merge_sample_states(self, sample_state_1, sample_state_2): + """Implements `tensorflow_privacy.DPQuery.merge_sample_states`.""" return self._map_to_queries( 'merge_sample_states', sample_state_1, sample_state_2) def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" estimates_and_new_global_states = self._map_to_queries( 'get_noised_result', sample_state, global_state) @@ -99,6 +113,7 @@ class NestedQuery(dp_query.DPQuery): tf.nest.pack_sequence_as(self._queries, flat_new_global_states)) def derive_metrics(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_metrics`.""" metrics = collections.OrderedDict() def add_metrics(tuple_path, subquery, subquery_global_state): @@ -122,6 +137,8 @@ class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery): Args: queries: A nested structure of queries that must all be SumAggregationDPQueries. + + Raises: TypeError if any of the subqueries are not SumAggregationDPQueries. """ def check(query): if not isinstance(query, dp_query.SumAggregationDPQuery): diff --git a/tensorflow_privacy/privacy/dp_query/no_privacy_query.py b/tensorflow_privacy/privacy/dp_query/no_privacy_query.py index 889fcb2..bee419c 100644 --- a/tensorflow_privacy/privacy/dp_query/no_privacy_query.py +++ b/tensorflow_privacy/privacy/dp_query/no_privacy_query.py @@ -34,13 +34,14 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery): self._ledger = None def set_ledger(self, ledger): + """Implements `tensorflow_privacy.DPQuery.set_ledger`.""" warnings.warn( 'Attempt to use NoPrivacySumQuery with privacy ledger. Privacy ' 'guarantees will be vacuous.') self._ledger = ledger def get_noised_result(self, sample_state, global_state): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" if self._ledger: dependencies = [ @@ -57,35 +58,67 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery): """Implements DPQuery interface for an average query with no privacy. Accumulates vectors and normalizes by the total number of accumulated vectors. + Under some sampling schemes, such as Poisson subsampling, the number of + records in a sample is a private quantity, so we lose all privacy guarantees + by using the number of records directly to normalize. + + Also allows weighted accumulation, unlike the base class DPQuery. In a private + implementation of weighted average, the weight would have to be itself + privatized. """ def __init__(self): + """Initializes the NoPrivacyAverageQuery.""" self._ledger = None def set_ledger(self, ledger): + """Implements `tensorflow_privacy.DPQuery.set_ledger`.""" warnings.warn( 'Attempt to use NoPrivacyAverageQuery with privacy ledger. Privacy ' 'guarantees will be vacuous.') self._ledger = ledger def initial_sample_state(self, template): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" return (super(NoPrivacyAverageQuery, self).initial_sample_state(template), tf.constant(0.0)) def preprocess_record(self, params, record, weight=1): - """Multiplies record by weight.""" + """Implements `tensorflow_privacy.DPQuery.preprocess_record`. + + Optional `weight` argument allows weighted accumulation. + + Args: + params: The parameters for the sample. + record: The record to accumulate. + weight: Optional weight for the record. + + Returns: + The preprocessed record. + """ weighted_record = tf.nest.map_structure(lambda t: weight * t, record) return (weighted_record, tf.cast(weight, tf.float32)) def accumulate_record(self, params, sample_state, record, weight=1): - """Accumulates record, multiplying by weight.""" + """Implements `tensorflow_privacy.DPQuery.accumulate_record`. + + Optional `weight` argument allows weighted accumulation. + + Args: + params: The parameters for the sample. + sample_state: The current sample state. + record: The record to accumulate. + weight: Optional weight for the record. + + Returns: + The updated sample state. + """ weighted_record = tf.nest.map_structure(lambda t: weight * t, record) return self.accumulate_preprocessed_record( sample_state, (weighted_record, tf.cast(weight, tf.float32))) def get_noised_result(self, sample_state, global_state): - """See base class.""" + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" sum_state, denominator = sample_state if self._ledger: diff --git a/tensorflow_privacy/privacy/dp_query/normalized_query.py b/tensorflow_privacy/privacy/dp_query/normalized_query.py index b59700f..2b9cdfc 100644 --- a/tensorflow_privacy/privacy/dp_query/normalized_query.py +++ b/tensorflow_privacy/privacy/dp_query/normalized_query.py @@ -27,14 +27,22 @@ from tensorflow_privacy.privacy.dp_query import dp_query class NormalizedQuery(dp_query.SumAggregationDPQuery): - """DPQuery for queries with a DPQuery numerator and fixed denominator.""" + """`DPQuery` for queries with a `DPQuery` numerator and fixed denominator. + + If the number of records per round is a public constant R, `NormalizedQuery` + could be used with a sum query as the numerator and R as the denominator to + implement an average. Under some sampling schemes, such as Poisson + subsampling, the actual number of records in a sample is a private quantity, + so we cannot use it directly. Using this class with the expected number of + records as the denominator gives an unbiased estimate of the average. + """ # pylint: disable=invalid-name _GlobalState = collections.namedtuple( '_GlobalState', ['numerator_state', 'denominator']) def __init__(self, numerator_query, denominator): - """Initializer for NormalizedQuery. + """Initializes the NormalizedQuery. Args: numerator_query: A SumAggregationDPQuery for the numerator. @@ -48,27 +56,30 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): assert isinstance(self._numerator, dp_query.SumAggregationDPQuery) def set_ledger(self, ledger): + """Implements `tensorflow_privacy.DPQuery.set_ledger`.""" self._numerator.set_ledger(ledger) def initial_global_state(self): - if self._denominator is not None: - denominator = tf.cast(self._denominator, tf.float32) - else: - denominator = None + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" + denominator = tf.cast(self._denominator, tf.float32) return self._GlobalState( self._numerator.initial_global_state(), denominator) def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return self._numerator.derive_sample_params(global_state.numerator_state) def initial_sample_state(self, template): + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" # NormalizedQuery has no sample state beyond the numerator state. return self._numerator.initial_sample_state(template) def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" return self._numerator.preprocess_record(params, record) def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" noised_sum, new_sum_global_state = self._numerator.get_noised_result( sample_state, global_state.numerator_state) def normalize(v): @@ -78,4 +89,5 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery): self._GlobalState(new_sum_global_state, global_state.denominator)) def derive_metrics(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_metrics`.""" return self._numerator.derive_metrics(global_state.numerator_state) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py index a7d4ef7..4d3cd2a 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query.py @@ -11,14 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Implements DPQuery interface for adaptive clip queries. - -Instead of a fixed clipping norm specified in advance, the clipping norm is -dynamically adjusted to match a target fraction of clipped updates per sample, -where the actual fraction of clipped updates is itself estimated in a -differentially private manner. For details see Thakkar et al., "Differentially -Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871]. -""" +"""`DPQuery` for Gaussian sum queries with adaptive clipping.""" from __future__ import absolute_import from __future__ import division @@ -34,10 +27,12 @@ from tensorflow_privacy.privacy.dp_query import quantile_estimator_query class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): - """DPQuery for sum queries with adaptive clipping. + """`DPQuery` for Gaussian sum queries with adaptive clipping. Clipping norm is tuned adaptively to converge to a value such that a specified - quantile of updates are clipped. + quantile of updates are clipped, using the algorithm of Andrew et al. ( + https://arxiv.org/abs/1905.03871). See the paper for details and suggested + hyperparameter settings. """ # pylint: disable=invalid-name @@ -65,20 +60,23 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): Args: initial_l2_norm_clip: The initial value of clipping norm. - noise_multiplier: The multiplier of the l2_norm_clip to make the stddev of - the noise added to the output of the sum query. + noise_multiplier: The stddev of the noise added to the output will be this + times the current value of the clipping norm. target_unclipped_quantile: The desired quantile of updates which should be unclipped. I.e., a value of 0.8 means a value of l2_norm_clip should be found for which approximately 20% of updates are clipped each round. - learning_rate: The learning rate for the clipping norm adaptation. A rate - of r means that the clipping norm will change by a maximum of r at each - step. This maximum is attained when |clip - target| is 1.0. + Andrew et al. recommends that this be set to 0.5 to clip to the median. + learning_rate: The learning rate for the clipping norm adaptation. With + geometric updating, a rate of r means that the clipping norm will change + by a maximum factor of exp(r) at each round. This maximum is attained + when |actual_unclipped_fraction - target_unclipped_quantile| is 1.0. + Andrew et al. recommends that this be set to 0.2 for geometric updating. clipped_count_stddev: The stddev of the noise added to the clipped_count. - Since the sensitivity of the clipped count is 0.5, as a rule of thumb it - should be about 0.5 for reasonable privacy. + Andrew et al. recommends that this be set to `expected_num_records / 20` + for reasonably fast adaptation and high privacy. expected_num_records: The expected number of records per round, used to estimate the clipped count quantile. - geometric_update: If True, use geometric updating of clip. + geometric_update: If `True`, use geometric updating of clip (recommended). """ self._noise_multiplier = noise_multiplier @@ -94,27 +92,32 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): dp_query.SumAggregationDPQuery) def set_ledger(self, ledger): + """Implements `tensorflow_privacy.DPQuery.set_ledger`.""" self._sum_query.set_ledger(ledger) self._quantile_estimator_query.set_ledger(ledger) def initial_global_state(self): + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" return self._GlobalState( tf.cast(self._noise_multiplier, tf.float32), self._sum_query.initial_global_state(), self._quantile_estimator_query.initial_global_state()) def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return self._SampleParams( self._sum_query.derive_sample_params(global_state.sum_state), self._quantile_estimator_query.derive_sample_params( global_state.quantile_estimator_state)) def initial_sample_state(self, template): + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" return self._SampleState( self._sum_query.initial_sample_state(template), self._quantile_estimator_query.initial_sample_state()) def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" clipped_record, global_norm = ( self._sum_query.preprocess_record_impl(params.sum_params, record)) @@ -124,6 +127,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): return self._SampleState(clipped_record, was_unclipped) def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" noised_vectors, sum_state = self._sum_query.get_noised_result( sample_state.sum_state, global_state.sum_state) del sum_state # To be set explicitly later when we know the new clip. @@ -145,4 +149,5 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery): return noised_vectors, new_global_state def derive_metrics(self, global_state): + """Returns the current clipping norm as a metric.""" return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py index 405d83f..4358a95 100644 --- a/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py +++ b/tensorflow_privacy/privacy/dp_query/quantile_estimator_query.py @@ -11,13 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Implements DPQuery interface for quantile estimator. - -From a starting estimate of the target quantile, the estimate is updated -dynamically where the fraction of below_estimate updates is estimated in a -differentially private manner. For details see Thakkar et al., "Differentially -Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871]. -""" +"""Implements DPQuery interface for quantile estimator.""" from __future__ import absolute_import from __future__ import division @@ -33,7 +27,11 @@ from tensorflow_privacy.privacy.dp_query import normalized_query class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): - """Iterative process to estimate target quantile of a univariate distribution.""" + """DPQuery to estimate target quantile of a univariate distribution. + + Uses the algorithm of Andrew et al. (https://arxiv.org/abs/1905.03871). See + the paper for details and suggested hyperparameter settings. + """ # pylint: disable=invalid-name _GlobalState = collections.namedtuple('_GlobalState', [ @@ -55,7 +53,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): below_estimate_stddev, expected_num_records, geometric_update=False): - """Initializes the QuantileAdaptiveClipSumQuery. + """Initializes the QuantileEstimatorQuery. Args: initial_estimate: The initial estimate of the quantile. @@ -64,11 +62,12 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): estimate each round. learning_rate: The learning rate. A rate of r means that the estimate will change by a maximum of r at each step (for arithmetic updating) or by a - maximum factor of exp(r) (for geometric updating). + maximum factor of exp(r) (for geometric updating). Andrew et al. + recommends that this be set to 0.2 for geometric updating. below_estimate_stddev: The stddev of the noise added to the count of - records currently below the estimate. Since the sensitivity of the count - query is 0.5, as a rule of thumb it should be about 0.5 for reasonable - privacy. + records currently below the estimate. Andrew et al. recommends that this + be set to `expected_num_records / 20` for reasonably fast adaptation and + high privacy. expected_num_records: The expected number of records per round. geometric_update: If True, use geometric updating of estimate. Geometric updating is preferred for non-negative records like vector norms that @@ -102,9 +101,11 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): denominator=expected_num_records) def set_ledger(self, ledger): + """Implements `tensorflow_privacy.DPQuery.set_ledger`.""" self._below_estimate_query.set_ledger(ledger) def initial_global_state(self): + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" return self._GlobalState( tf.cast(self._initial_estimate, tf.float32), tf.cast(self._target_quantile, tf.float32), @@ -112,39 +113,42 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): self._below_estimate_query.initial_global_state()) def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" below_estimate_params = self._below_estimate_query.derive_sample_params( global_state.below_estimate_state) return self._SampleParams(global_state.current_estimate, below_estimate_params) def initial_sample_state(self, template=None): + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" # Template is ignored because records are required to be scalars. del template return self._below_estimate_query.initial_sample_state(0.0) def preprocess_record(self, params, record): + """Implements `tensorflow_privacy.DPQuery.preprocess_record`.""" tf.debugging.assert_scalar(record) - # We accumulate counts shifted by 0.5 so they are centered at zero. - # This makes the sensitivity of the count query 0.5 instead of 1.0. + # Shift counts by 0.5 so they are centered at zero. (See comment in + # `_construct_below_estimate_query`.) below = tf.cast(record <= params.current_estimate, tf.float32) - 0.5 return self._below_estimate_query.preprocess_record( params.below_estimate_params, below) def get_noised_result(self, sample_state, global_state): + """Implements `tensorflow_privacy.DPQuery.get_noised_result`.""" below_estimate_result, new_below_estimate_state = ( self._below_estimate_query.get_noised_result( sample_state, global_state.below_estimate_state)) - # Unshift below_estimate percentile by 0.5. (See comment in initializer.) + # Unshift below_estimate percentile by 0.5. (See comment in + # `_construct_below_estimate_query`.) below_estimate = below_estimate_result + 0.5 # Protect against out-of-range estimates. below_estimate = tf.minimum(1.0, tf.maximum(0.0, below_estimate)) - # Loss function is convex, with derivative in [-1, 1], and minimized when - # the true quantile matches the target. loss_grad = below_estimate - global_state.target_quantile update = global_state.learning_rate * loss_grad @@ -161,6 +165,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery): return new_estimate, new_global_state def derive_metrics(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_metrics`.""" return collections.OrderedDict(estimate=global_state.current_estimate) @@ -168,7 +173,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery): """Iterative process to estimate target quantile of a univariate distribution. Unlike the base class, this uses a NoPrivacyQuery to estimate the fraction - below estimate with an exact denominator. + below estimate with an exact denominator, so there are no privacy guarantees. """ def __init__(self, @@ -185,7 +190,8 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery): estimate each round. learning_rate: The learning rate. A rate of r means that the estimate will change by a maximum of r at each step (for arithmetic updating) or by a - maximum factor of exp(r) (for geometric updating). + maximum factor of exp(r) (for geometric updating). Andrew et al. + recommends that this be set to 0.2 for geometric updating. geometric_update: If True, use geometric updating of estimate. Geometric updating is preferred for non-negative records like vector norms that could potentially be very large or very close to zero. diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py index ad84053..58d4094 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py @@ -27,29 +27,63 @@ import tensorflow as tf class ValueGenerator(metaclass=abc.ABCMeta): - """Base class establishing interface for stateful value generation.""" + """Base class establishing interface for stateful value generation. + + A `ValueGenerator` maintains a state, and each time `next` is called, a new + value is generated and the state is advanced. + """ @abc.abstractmethod def initialize(self): - """Returns initialized state.""" + """Makes an initialized state for the ValueGenerator. + + Returns: + An initial state. + """ @abc.abstractmethod def next(self, state): - """Returns tree node value and updated state.""" + """Gets next value and advances the ValueGenerator. + + Args: + state: The current state. + + Returns: + A pair (value, new_state) where value is the next value and new_state + is the advanced state. + """ class GaussianNoiseGenerator(ValueGenerator): - """Gaussian noise generator with counter as pseudo state.""" + """Gaussian noise generator with counter as pseudo state. + + Produces i.i.d. spherical Gaussian noise at each step shaped according to a + nested structure of `tf.TensorSpec`s. + """ def __init__(self, noise_std: float, specs: Collection[tf.TensorSpec], seed: Optional[int] = None): + """Initializes the GaussianNoiseGenerator. + + Args: + noise_std: The standard deviation of the noise. + specs: A nested structure of `tf.TensorSpec`s specifying the shape of the + noise to generate. + seed: An optional integer seed. If None, generator is seeded from the + clock. + """ self.noise_std = noise_std self.specs = specs self.seed = seed def initialize(self): + """Makes an initial state for the GaussianNoiseGenerator. + + Returns: + An initial state. + """ if self.seed is None: return tf.cast( tf.stack([ @@ -61,6 +95,15 @@ class GaussianNoiseGenerator(ValueGenerator): return tf.constant(self.seed, dtype=tf.int64, shape=(2,)) def next(self, state): + """Gets next value and advances the GaussianNoiseGenerator. + + Args: + state: The current state. + + Returns: + A pair (sample, new_state) where sample is a new sample and new_state + is the advanced state. + """ flat_structure = tf.nest.flatten(self.specs) flat_seeds = [state + i for i in range(len(flat_structure))] nest_seeds = tf.nest.pack_sequence_as(self.specs, flat_seeds) @@ -74,15 +117,34 @@ class GaussianNoiseGenerator(ValueGenerator): class StatelessValueGenerator(ValueGenerator): - """A wrapper for stateless value generator initialized by a no-arg function.""" + """A wrapper for stateless value generator that calls a no-arg function.""" def __init__(self, value_fn): + """Initializes the StatelessValueGenerator. + + Args: + value_fn: The function to call to generate values. + """ self.value_fn = value_fn def initialize(self): + """Makes an initialized state for the StatelessValueGenerator. + + Returns: + An initial state (empty, because stateless). + """ return () def next(self, state): + """Gets next value. + + Args: + state: The current state (simply passed through). + + Returns: + A pair (value, new_state) where value is the next value and new_state + is the advanced state. + """ return self.value_fn(), state @@ -127,7 +189,12 @@ class TreeAggregator(): """ def __init__(self, value_generator: Union[ValueGenerator, Callable[[], Any]]): - """Initialize the aggregator with a noise generator.""" + """Initialize the aggregator with a noise generator. + + Args: + value_generator: A `ValueGenerator` or a no-arg function to generate a + noise value for each tree node. + """ if isinstance(value_generator, ValueGenerator): self.value_generator = value_generator else: @@ -235,7 +302,7 @@ class EfficientTreeAggregator(): This class implements the efficient tree aggregation algorithm based on Honaker 2015 "Efficient Use of Differentially Private Binary Trees". - The noise standard deviation for the note at depth d is roughly + The noise standard deviation for a node at depth d is roughly `sigma * sqrt(2^{d-1}/(2^d-1))`. which becomes `sigma / sqrt(2)` when the tree is very tall. @@ -245,7 +312,12 @@ class EfficientTreeAggregator(): """ def __init__(self, value_generator: Union[ValueGenerator, Callable[[], Any]]): - """Initialize the aggregator with a noise generator.""" + """Initialize the aggregator with a noise generator. + + Args: + value_generator: A `ValueGenerator` or a no-arg function to generate a + noise value for each tree node. + """ if isinstance(value_generator, ValueGenerator): self.value_generator = value_generator else: @@ -257,6 +329,9 @@ class EfficientTreeAggregator(): Initializes `TreeState` for a tree of a single leaf node: the respective initial node value in `TreeState.level_buffer` is generated by the value generator function, and the node index is 0. + + Returns: + An initialized `TreeState`. """ value_generator_state = self.value_generator.initialize() level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True) diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 5580222..fb7dc76 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -68,7 +68,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): `TreeCumulativeSumQuery` with L2 norm clipping and Gaussian noise. Args: - record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + record_specs: A nested structure of `tf.TensorSpec`s specifying structure + and shapes of records. noise_generator: `tree_aggregation.ValueGenerator` to generate the noise value for a tree node. Should be coupled with clipping norm to guarantee privacy. @@ -89,7 +90,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) def initial_global_state(self): - """Returns initial global state.""" + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() initial_samples_cumulative_sum = tf.nest.map_structure( lambda spec: tf.zeros(spec.shape), self._record_specs) @@ -100,10 +101,11 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): return initial_state def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return global_state.clip_value def preprocess_record(self, params, record): - """Returns the clipped record using `clip_fn` and params. + """Implements `tensorflow_privacy.DPQuery.preprocess_record`. Args: params: `clip_value` for the record. @@ -118,14 +120,16 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): return tf.nest.pack_sequence_as(record, clipped_as_list) def get_noised_result(self, sample_state, global_state): - """Updates tree, state, and returns noised cumulative sum and updated state. + """Implements `tensorflow_privacy.DPQuery.get_noised_result`. - Computes new cumulative sum, and returns its noised value. Grows tree_state + Updates tree state, and returns noised cumulative sum and updated state. + + Computes new cumulative sum, and returns its noised value. Grows tree state by one new leaf, and returns the new state. Args: sample_state: Sum of clipped records for this round. - global_state: Global state with current samples cumulative sum and tree + global_state: Global state with current sample's cumulative sum and tree state. Returns: @@ -157,7 +161,8 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): `clip_norm`. noise_multiplier: The effective noise multiplier for the sum of records. Noise standard deviation is `clip_norm*noise_multiplier`. - record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + record_specs: A nested structure of `tf.TensorSpec`s specifying structure + and shapes of records. noise_seed: Integer seed for the Gaussian noise generator. If `None`, a nondeterministic seed based on system time will be generated. use_efficient: Boolean indicating the usage of the efficient tree @@ -204,9 +209,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): Attributes: clip_fn: Callable that specifies clipping function. `clip_fn` receives two arguments: a flat list of vars in a record and a `clip_value` to clip the - corresponding record, e.g. clip_fn(flat_record, clip_value). + corresponding record, e.g. clip_fn(flat_record, clip_value). clip_value: float indicating the value at which to clip the record. - record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + record_specs: A nested structure of `tf.TensorSpec`s specifying structure + and shapes of records. tree_aggregator: `tree_aggregation.TreeAggregator` initialized with user defined `noise_generator`. `noise_generator` is a `tree_aggregation.ValueGenerator` to generate the noise value for a tree @@ -242,7 +248,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): `TreeResidualSumQuery` with L2 norm clipping and Gaussian noise. Args: - record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + record_specs: A nested structure of `tf.TensorSpec`s specifying structure + and shapes of records. noise_generator: `tree_aggregation.ValueGenerator` to generate the noise value for a tree node. Should be coupled with clipping norm to guarantee privacy. @@ -263,7 +270,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator) def initial_global_state(self): - """Returns initial global state.""" + """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" initial_tree_state = self._tree_aggregator.init_state() initial_noise = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape), self._record_specs) @@ -273,10 +280,11 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): previous_tree_noise=initial_noise) def derive_sample_params(self, global_state): + """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return global_state.clip_value def preprocess_record(self, params, record): - """Returns the clipped record using `clip_fn` and params. + """Implements `tensorflow_privacy.DPQuery.preprocess_record`. Args: params: `clip_value` for the record. @@ -291,7 +299,9 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): return tf.nest.pack_sequence_as(record, clipped_as_list) def get_noised_result(self, sample_state, global_state): - """Updates tree state, and returns residual of noised cumulative sum. + """Implements `tensorflow_privacy.DPQuery.get_noised_result`. + + Updates tree state, and returns residual of noised cumulative sum. Args: sample_state: Sum of clipped records for this round. @@ -324,7 +334,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): `clip_norm`. noise_multiplier: The effective noise multiplier for the sum of records. Noise standard deviation is `clip_norm*noise_multiplier`. - record_specs: `Collection[tf.TensorSpec]` specifying shapes of records. + record_specs: A nested structure of `tf.TensorSpec`s specifying structure + and shapes of records. noise_seed: Integer seed for the Gaussian noise generator. If `None`, a nondeterministic seed based on system time will be generated. use_efficient: Boolean indicating the usage of the efficient tree