From aaf029edadff577f5882d396f88fc294616194c4 Mon Sep 17 00:00:00 2001
From: Galen Andrew
Date: Tue, 14 May 2019 13:35:08 -0700
Subject: [PATCH] Add quantile_adaptive_clip_sum_query which dynamically
 adjusts the clipping norm so a specified fraction of records per sample are
 clipped.

PiperOrigin-RevId: 248201320
---
 privacy/__init__.py                           |   2 +
 privacy/dp_query/BUILD                        |  23 ++
 .../quantile_adaptive_clip_sum_query.py       | 271 ++++++++++++++++
 .../quantile_adaptive_clip_sum_query_test.py  | 298 ++++++++++++++++++
 4 files changed, 594 insertions(+)
 create mode 100644 privacy/dp_query/quantile_adaptive_clip_sum_query.py
 create mode 100644 privacy/dp_query/quantile_adaptive_clip_sum_query_test.py

diff --git a/privacy/__init__.py b/privacy/__init__.py
index 176623c..4aa5cc5 100644
--- a/privacy/__init__.py
+++ b/privacy/__init__.py
@@ -33,6 +33,8 @@ else:
   from privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
   from privacy.dp_query.no_privacy_query import NoPrivacySumQuery
   from privacy.dp_query.normalized_query import NormalizedQuery
+  from privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
+  from privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery
 
   from privacy.optimizers.dp_optimizer import DPAdagradGaussianOptimizer
   from privacy.optimizers.dp_optimizer import DPAdagradOptimizer
diff --git a/privacy/dp_query/BUILD b/privacy/dp_query/BUILD
index e91dce4..2494096 100644
--- a/privacy/dp_query/BUILD
+++ b/privacy/dp_query/BUILD
@@ -105,6 +105,29 @@ py_test(
     ],
 )
 
+py_library(
+    name = "quantile_adaptive_clip_sum_query",
+    srcs = ["quantile_adaptive_clip_sum_query.py"],
+    deps = [
+        ":dp_query",
+        ":gaussian_query",
+        ":normalized_query",
+        "//third_party/py/tensorflow",
+    ],
+)
+
+py_test(
+    name = "quantile_adaptive_clip_sum_query_test",
+    srcs = ["quantile_adaptive_clip_sum_query_test.py"],
+    deps = [
+        ":quantile_adaptive_clip_sum_query",
+        ":test_utils",
+        "//third_party/py/numpy",
+        "//third_party/py/tensorflow",
+        "//third_party/py/tensorflow_privacy/privacy/analysis:privacy_ledger",
+    ],
+)
+
 py_library(
     name = "test_utils",
     srcs = ["test_utils.py"],
diff --git a/privacy/dp_query/quantile_adaptive_clip_sum_query.py b/privacy/dp_query/quantile_adaptive_clip_sum_query.py
new file mode 100644
index 0000000..3391eb1
--- /dev/null
+++ b/privacy/dp_query/quantile_adaptive_clip_sum_query.py
@@ -0,0 +1,271 @@
+# Copyright 2019, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implements DPQuery interface for adaptive clip queries.
+
+Instead of a fixed clipping norm specified in advance, the clipping norm is
+dynamically adjusted to match a target fraction of clipped updates per sample,
+where the actual fraction of clipped updates is itself estimated in a
+differentially private manner. For details see Thakkar et al., "Differentially
+Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871].
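+
+Concretely, after each sample the clipping norm C is updated as
+C <- max(0, C - eta * (b - gamma)), where b is the privately estimated
+fraction of unclipped records, gamma is the target unclipped quantile, and
+eta is the learning rate. See get_noised_result below.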
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import tensorflow as tf + +from privacy.dp_query import dp_query +from privacy.dp_query import gaussian_query +from privacy.dp_query import normalized_query + +nest = tf.contrib.framework.nest + + +class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): + """DPQuery for sum queries with adaptive clipping. + + Clipping norm is tuned adaptively to converge to a value such that a specified + quantile of updates are clipped. + """ + + # pylint: disable=invalid-name + _GlobalState = collections.namedtuple( + '_GlobalState', ['l2_norm_clip', 'sum_state', 'clipped_fraction_state']) + + # pylint: disable=invalid-name + _SampleState = collections.namedtuple( + '_SampleState', ['sum_state', 'clipped_fraction_state']) + + # pylint: disable=invalid-name + _SampleParams = collections.namedtuple( + '_SampleParams', ['sum_params', 'clipped_fraction_params']) + + def __init__( + self, + initial_l2_norm_clip, + noise_multiplier, + target_unclipped_quantile, + learning_rate, + clipped_count_stddev, + expected_num_records, + ledger=None): + """Initializes the QuantileAdaptiveClipSumQuery. + + Args: + initial_l2_norm_clip: The initial value of clipping norm. + noise_multiplier: The multiplier of the l2_norm_clip to make the stddev of + the noise added to the output of the sum query. + target_unclipped_quantile: The desired quantile of updates which should be + unclipped. I.e., a value of 0.8 means a value of l2_norm_clip should be + found for which approximately 20% of updates are clipped each round. + learning_rate: The learning rate for the clipping norm adaptation. A + rate of r means that the clipping norm will change by a maximum of r at + each step. This maximum is attained when |clip - target| is 1.0. Can be + a tf.Variable for example to implement a learning rate schedule. + clipped_count_stddev: The stddev of the noise added to the clipped_count. + Since the sensitivity of the clipped count is 0.5, as a rule of thumb it + should be about 0.5 for reasonable privacy. + expected_num_records: The expected number of records per round, used to + estimate the clipped count quantile. + ledger: The privacy ledger to which queries should be recorded. + """ + self._initial_l2_norm_clip = tf.cast(initial_l2_norm_clip, tf.float32) + self._noise_multiplier = tf.cast(noise_multiplier, tf.float32) + self._target_unclipped_quantile = tf.cast( + target_unclipped_quantile, tf.float32) + self._learning_rate = tf.cast(learning_rate, tf.float32) + + self._l2_norm_clip = tf.Variable(self._initial_l2_norm_clip) + self._sum_stddev = tf.Variable( + self._initial_l2_norm_clip * self._noise_multiplier) + self._sum_query = gaussian_query.GaussianSumQuery( + self._l2_norm_clip, + self._sum_stddev, + ledger) + + # self._clipped_fraction_query is a DPQuery used to estimate the fraction of + # records that are clipped. It accumulates an indicator 0/1 of whether each + # record is clipped, and normalizes by the expected number of records. In + # practice, we accumulate clipped counts shifted by -0.5 so they are + # centered at zero. This makes the sensitivity of the clipped count query + # 0.5 instead of 1.0, since the maximum that a single record could affect + # the count is 0.5. Note that although the l2_norm_clip of the clipped + # fraction query is 0.5, no clipping will ever actually occur because the + # value of each record is always +/-0.5. 
+    self._clipped_fraction_query = gaussian_query.GaussianAverageQuery(
+        l2_norm_clip=0.5,
+        sum_stddev=clipped_count_stddev,
+        denominator=expected_num_records,
+        ledger=ledger)
+
+  def initial_global_state(self):
+    """See base class."""
+    return self._GlobalState(
+        self._initial_l2_norm_clip,
+        self._sum_query.initial_global_state(),
+        self._clipped_fraction_query.initial_global_state())
+
+  def derive_sample_params(self, global_state):
+    """See base class."""
+    gs = global_state
+
+    # Assign values to the variables that the inner sum query uses.
+    tf.assign(self._l2_norm_clip, gs.l2_norm_clip)
+    tf.assign(self._sum_stddev, gs.l2_norm_clip * self._noise_multiplier)
+    sum_params = self._sum_query.derive_sample_params(gs.sum_state)
+    clipped_fraction_params = self._clipped_fraction_query.derive_sample_params(
+        gs.clipped_fraction_state)
+    return self._SampleParams(sum_params, clipped_fraction_params)
+
+  def initial_sample_state(self, global_state, template):
+    """See base class."""
+    clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
+        global_state.clipped_fraction_state, tf.constant(0.0))
+    sum_state = self._sum_query.initial_sample_state(
+        global_state.sum_state, template)
+    return self._SampleState(sum_state, clipped_fraction_state)
+
+  def preprocess_record(self, params, record):
+    """See base class."""
+    preprocessed_sum_record, global_norm = (
+        self._sum_query.preprocess_record_impl(params.sum_params, record))
+
+    # Note we are relying on the internals of GaussianSumQuery here. If we want
+    # to open this up to other kinds of inner queries we'd have to do this in a
+    # more general way.
+    l2_norm_clip = params.sum_params
+
+    # We accumulate clipped counts shifted by -0.5 so they are centered at
+    # zero. This makes the sensitivity of the clipped count query 0.5 instead
+    # of 1.0.
+    was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5
+
+    preprocessed_clipped_fraction_record = (
+        self._clipped_fraction_query.preprocess_record(
+            params.clipped_fraction_params, was_clipped))
+
+    return preprocessed_sum_record, preprocessed_clipped_fraction_record
+
+  def accumulate_preprocessed_record(
+      self, sample_state, preprocessed_record, weight=1):
+    """See base class."""
+    preprocessed_sum_record, preprocessed_clipped_fraction_record = (
+        preprocessed_record)
+    sum_state = self._sum_query.accumulate_preprocessed_record(
+        sample_state.sum_state, preprocessed_sum_record)
+
+    clipped_fraction_state = (
+        self._clipped_fraction_query.accumulate_preprocessed_record(
+            sample_state.clipped_fraction_state,
+            preprocessed_clipped_fraction_record))
+    return self._SampleState(sum_state, clipped_fraction_state)
+
+  def merge_sample_states(self, sample_state_1, sample_state_2):
+    """See base class."""
+    return self._SampleState(
+        self._sum_query.merge_sample_states(
+            sample_state_1.sum_state,
+            sample_state_2.sum_state),
+        self._clipped_fraction_query.merge_sample_states(
+            sample_state_1.clipped_fraction_state,
+            sample_state_2.clipped_fraction_state))
+
+  def get_noised_result(self, sample_state, global_state):
+    """See base class."""
+    gs = global_state
+
+    noised_vectors, sum_state = self._sum_query.get_noised_result(
+        sample_state.sum_state, gs.sum_state)
+
+    clipped_fraction_result, new_clipped_fraction_state = (
+        self._clipped_fraction_query.get_noised_result(
+            sample_state.clipped_fraction_state,
+            gs.clipped_fraction_state))
+
+    # Unshift the clipped fraction by 0.5. (See comment in preprocess_record.)
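+    # clipped_fraction_result is the noised average of the shifted +/-0.5
+    # indicators, i.e. approximately (clipped_count / expected_num_records)
+    # - 0.5, so adding 0.5 recovers an estimate of the clipped fraction.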
+    clipped_quantile = clipped_fraction_result + 0.5
+    unclipped_quantile = 1.0 - clipped_quantile
+
+    # Protect against out-of-range estimates.
+    unclipped_quantile = tf.minimum(1.0, tf.maximum(0.0, unclipped_quantile))
+
+    # The loss function is convex, with derivative in [-1, 1], and minimized
+    # when the true quantile matches the target.
+    loss_grad = unclipped_quantile - self._target_unclipped_quantile
+
+    new_l2_norm_clip = gs.l2_norm_clip - self._learning_rate * loss_grad
+    new_l2_norm_clip = tf.maximum(0.0, new_l2_norm_clip)
+
+    new_global_state = self._GlobalState(
+        new_l2_norm_clip,
+        sum_state,
+        new_clipped_fraction_state)
+
+    return noised_vectors, new_global_state
+
+
+class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
+  """DPQuery for average queries with adaptive clipping.
+
+  Clipping norm is tuned adaptively to converge to a value such that a
+  specified quantile of updates are clipped.
+
+  Note that we use "fixed-denominator" estimation: the denominator should be
+  specified as the expected number of records per sample. Accumulating the
+  denominator separately would also be possible but would produce a
+  higher-variance estimator.
+  """
+
+  def __init__(
+      self,
+      initial_l2_norm_clip,
+      noise_multiplier,
+      denominator,
+      target_unclipped_quantile,
+      learning_rate,
+      clipped_count_stddev,
+      expected_num_records,
+      ledger=None):
+    """Initializes the QuantileAdaptiveClipAverageQuery.
+
+    Args:
+      initial_l2_norm_clip: The initial value of the clipping norm.
+      noise_multiplier: The stddev of the noise added to the sum, expressed as
+        a multiple of l2_norm_clip.
+      denominator: The normalization constant (applied after noise is added to
+        the sum).
+      target_unclipped_quantile: The desired quantile of updates that should be
+        unclipped.
+      learning_rate: The learning rate for the clipping norm adaptation. A rate
+        of r means that the clipping norm will change by a maximum of r at each
+        step. The maximum is attained when |actual_unclipped_fraction -
+        target_unclipped_quantile| is 1.0.
+      clipped_count_stddev: The stddev of the noise added to the clipped_count.
+        Since the sensitivity of the clipped count is 0.5, as a rule of thumb
+        it should be about 0.5 for reasonable privacy.
+      expected_num_records: The expected number of records, used to normalize
+        the clipped count into a clipped fraction.
+      ledger: The privacy ledger to which queries should be recorded.
+    """
+    numerator_query = QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip,
+        noise_multiplier,
+        target_unclipped_quantile,
+        learning_rate,
+        clipped_count_stddev,
+        expected_num_records,
+        ledger)
+    super(QuantileAdaptiveClipAverageQuery, self).__init__(
+        numerator_query=numerator_query,
+        denominator=denominator)
diff --git a/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py
new file mode 100644
index 0000000..f24c9c0
--- /dev/null
+++ b/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py
@@ -0,0 +1,298 @@
+# Copyright 2019, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for QuantileAdaptiveClipSumQuery."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from privacy.analysis import privacy_ledger
+from privacy.dp_query import quantile_adaptive_clip_sum_query
+from privacy.dp_query import test_utils
+
+tf.enable_eager_execution()
+
+
+class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
+
+  def test_sum_no_clip_no_noise(self):
+    record1 = tf.constant([2.0, 0.0])
+    record2 = tf.constant([-1.0, 1.0])
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=10.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+    query_result, _ = test_utils.run_query(query, [record1, record2])
+    result = query_result.numpy()
+    expected = [1.0, 1.0]
+    self.assertAllClose(result, expected)
+
+  def test_sum_with_clip_no_noise(self):
+    record1 = tf.constant([-6.0, 8.0])  # Norm 10.0; clipped to [-3.0, 4.0].
+    record2 = tf.constant([4.0, -3.0])  # Norm 5.0; not clipped.
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=5.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    query_result, _ = test_utils.run_query(query, [record1, record2])
+    result = query_result.numpy()
+    expected = [1.0, 1.0]
+    self.assertAllClose(result, expected)
+
+  def test_sum_with_noise(self):
+    record1, record2 = 2.71828, 3.14159
+    stddev = 1.0
+    clip = 5.0
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=clip,
+        noise_multiplier=stddev / clip,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    noised_sums = []
+    for _ in range(1000):
+      query_result, _ = test_utils.run_query(query, [record1, record2])
+      noised_sums.append(query_result.numpy())
+
+    result_stddev = np.std(noised_sums)
+    self.assertNear(result_stddev, stddev, 0.1)
+
+  def test_average_no_noise(self):
+    record1 = tf.constant([5.0, 0.0])   # Norm 5.0; clipped to [3.0, 0.0].
+    record2 = tf.constant([-1.0, 2.0])  # Norm ~2.24; not clipped.
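+    # Expected average: ([3.0, 0.0] + [-1.0, 2.0]) / 2 = [1.0, 1.0].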
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
+        initial_l2_norm_clip=3.0,
+        noise_multiplier=0.0,
+        denominator=2.0,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+    query_result, _ = test_utils.run_query(query, [record1, record2])
+    result = query_result.numpy()
+    expected_average = [1.0, 1.0]
+    self.assertAllClose(result, expected_average)
+
+  def test_average_with_noise(self):
+    record1, record2 = 2.71828, 3.14159
+    sum_stddev = 1.0
+    denominator = 2.0
+    clip = 3.0
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
+        initial_l2_norm_clip=clip,
+        noise_multiplier=sum_stddev / clip,
+        denominator=denominator,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    noised_averages = []
+    for _ in range(1000):
+      query_result, _ = test_utils.run_query(query, [record1, record2])
+      noised_averages.append(query_result.numpy())
+
+    result_stddev = np.std(noised_averages)
+    avg_stddev = sum_stddev / denominator
+    self.assertNear(result_stddev, avg_stddev, 0.1)
+
+  def test_adaptation_target_zero(self):
+    record1 = tf.constant([8.5])
+    record2 = tf.constant([-7.25])
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=10.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=0.0,
+        learning_rate=1.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    global_state = query.initial_global_state()
+
+    initial_clip = global_state.l2_norm_clip
+    self.assertAllClose(initial_clip, 10.0)
+
+    # On the first two iterations, nothing is clipped, so the clip goes down
+    # by 1.0 (the learning rate). When the clip reaches 8.0, one record is
+    # clipped, so the clip goes down by only 0.5. After two more iterations,
+    # both records are clipped, and the clip norm stays there (at 7.0).
+
+    expected_sums = [1.25, 1.25, 0.75, 0.25, 0.0]
+    expected_clips = [9.0, 8.0, 7.5, 7.0, 7.0]
+    for expected_sum, expected_clip in zip(expected_sums, expected_clips):
+      actual_sum, global_state = test_utils.run_query(
+          query, [record1, record2], global_state)
+
+      actual_clip = global_state.l2_norm_clip
+
+      self.assertAllClose(actual_clip.numpy(), expected_clip)
+      self.assertAllClose(actual_sum.numpy(), (expected_sum,))
+
+  def test_adaptation_target_one(self):
+    record1 = tf.constant([-1.5])
+    record2 = tf.constant([2.75])
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=0.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=1.0,
+        learning_rate=1.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    global_state = query.initial_global_state()
+
+    initial_clip = global_state.l2_norm_clip
+    self.assertAllClose(initial_clip, 0.0)
+
+    # On the first two iterations, both records are clipped, so the clip goes
+    # up by 1.0 (the learning rate). When the clip reaches 2.0, only one record
+    # is clipped, so the clip goes up by only 0.5. After two more iterations,
+    # neither record is clipped, and the clip norm stays there (at 3.0).
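+    # For example, on the first iteration both records are clipped, so the
+    # (noiseless) unclipped fraction is 0.0 against a target of 1.0:
+    # loss_grad = 0.0 - 1.0 = -1.0, and the clip moves from 0.0 to
+    # 0.0 - 1.0 * (-1.0) = 1.0.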
+
+    expected_sums = [0.0, 0.0, 0.5, 1.0, 1.25]
+    expected_clips = [1.0, 2.0, 2.5, 3.0, 3.0]
+    for expected_sum, expected_clip in zip(expected_sums, expected_clips):
+      actual_sum, global_state = test_utils.run_query(
+          query, [record1, record2], global_state)
+
+      actual_clip = global_state.l2_norm_clip
+
+      self.assertAllClose(actual_clip.numpy(), expected_clip)
+      self.assertAllClose(actual_sum.numpy(), (expected_sum,))
+
+  def test_adaptation_linspace(self):
+    # 21 records equally spaced from 0 to 10 in 0.5 increments.
+    # Test that with a decaying learning rate we converge to the correct
+    # median with error at most 0.25.
+    records = [tf.constant(x) for x in np.linspace(
+        0.0, 10.0, num=21, dtype=np.float32)]
+
+    learning_rate = tf.Variable(1.0)
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=0.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=0.5,
+        learning_rate=learning_rate,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    global_state = query.initial_global_state()
+
+    for t in range(50):
+      tf.assign(learning_rate, 1.0 / np.sqrt(t + 1))
+      _, global_state = test_utils.run_query(query, records, global_state)
+
+      actual_clip = global_state.l2_norm_clip
+
+      if t > 40:
+        self.assertNear(actual_clip, 5.0, 0.25)
+
+  def test_adaptation_all_equal(self):
+    # 20 equal records. Test that with a decaying learning rate we converge to
+    # that record and bounce around it.
+    records = [tf.constant(5.0)] * 20
+
+    learning_rate = tf.Variable(1.0)
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=0.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=0.5,
+        learning_rate=learning_rate,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    global_state = query.initial_global_state()
+
+    for t in range(50):
+      tf.assign(learning_rate, 1.0 / np.sqrt(t + 1))
+      _, global_state = test_utils.run_query(query, records, global_state)
+
+      actual_clip = global_state.l2_norm_clip
+
+      if t > 40:
+        self.assertNear(actual_clip, 5.0, 0.25)
+
+  def test_ledger(self):
+    record1 = tf.constant([8.5])
+    record2 = tf.constant([-7.25])
+
+    population_size = tf.Variable(0)
+    selection_probability = tf.Variable(0.0)
+    ledger = privacy_ledger.PrivacyLedger(
+        population_size, selection_probability, 50, 50)
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=10.0,
+        noise_multiplier=1.0,
+        target_unclipped_quantile=0.0,
+        learning_rate=1.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0,
+        ledger=ledger)
+
+    query = privacy_ledger.QueryWithLedger(query, ledger)
+
+    # First sample.
+    tf.assign(population_size, 10)
+    tf.assign(selection_probability, 0.1)
+    _, global_state = test_utils.run_query(query, [record1, record2])
+
+    # Each sample registers two queries on the ledger: the clipped-count query
+    # (clip 0.5, stddev 0.0) and the sum query (clip 10.0, stddev 10.0).
+    expected_queries = [[0.5, 0.0], [10.0, 10.0]]
+    formatted = ledger.get_formatted_ledger_eager()
+    sample_1 = formatted[0]
+    self.assertAllClose(sample_1.population_size, 10.0)
+    self.assertAllClose(sample_1.selection_probability, 0.1)
+    self.assertAllClose(sample_1.queries, expected_queries)
+
+    # Second sample.
+    tf.assign(population_size, 20)
+    tf.assign(selection_probability, 0.2)
+    test_utils.run_query(query, [record1, record2], global_state)
+
+    formatted = ledger.get_formatted_ledger_eager()
+    sample_1, sample_2 = formatted
+    self.assertAllClose(sample_1.population_size, 10.0)
+    self.assertAllClose(sample_1.selection_probability, 0.1)
+    self.assertAllClose(sample_1.queries, expected_queries)
+
+    # After the first sample the clip has adapted from 10.0 down to 9.0, so
+    # the second sample's sum query is recorded with clip 9.0 and stddev 9.0.
+    expected_queries_2 = [[0.5, 0.0], [9.0, 9.0]]
+    self.assertAllClose(sample_2.population_size, 20.0)
+    self.assertAllClose(sample_2.selection_probability, 0.2)
+    self.assertAllClose(sample_2.queries, expected_queries_2)
+
+
+if __name__ == '__main__':
+  tf.test.main()
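+
+
+# A minimal usage sketch, not part of the test suite: assuming the standard
+# DPQuery protocol (the same one test_utils.run_query drives), one round of
+# an adaptive-clip sum looks roughly like the following. The variables
+# `records` and `template_record` (a zero-valued structure shaped like one
+# record) are hypothetical.
+#
+#   query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+#       initial_l2_norm_clip=1.0, noise_multiplier=1.0,
+#       target_unclipped_quantile=0.5, learning_rate=0.2,
+#       clipped_count_stddev=0.5, expected_num_records=100)
+#   global_state = query.initial_global_state()
+#   params = query.derive_sample_params(global_state)
+#   sample_state = query.initial_sample_state(global_state, template_record)
+#   for record in records:
+#     sample_state = query.accumulate_record(params, sample_state, record)
+#   noised_sum, global_state = query.get_noised_result(
+#       sample_state, global_state)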