Remove QuantileAdaptiveClipAverageQuery. Users can simply wrap QuantileAdaptiveClipSumQuery with a NormalizedQuery.

PiperOrigin-RevId: 374770867
This commit is contained in:
Galen Andrew 2021-05-19 18:10:19 -07:00 committed by A. Unique TensorFlower
parent aaf4c252a0
commit 1de7e4dde4
3 changed files with 48 additions and 158 deletions

View file

@ -48,7 +48,6 @@ else:
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery
# Estimators # Estimators
from tensorflow_privacy.privacy.estimators.dnn import DNNClassifier from tensorflow_privacy.privacy.estimators.dnn import DNNClassifier

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Implements DPQuery interface for adaptive clip queries. """Implements DPQuery interface for adaptive clip queries.
Instead of a fixed clipping norm specified in advance, the clipping norm is Instead of a fixed clipping norm specified in advance, the clipping norm is
@ -31,7 +30,6 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import normalized_query
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
@ -44,10 +42,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
# pylint: disable=invalid-name # pylint: disable=invalid-name
_GlobalState = collections.namedtuple( _GlobalState = collections.namedtuple(
'_GlobalState', [ '_GlobalState',
'noise_multiplier', ['noise_multiplier', 'sum_state', 'quantile_estimator_state'])
'sum_state',
'quantile_estimator_state'])
# pylint: disable=invalid-name # pylint: disable=invalid-name
_SampleState = collections.namedtuple( _SampleState = collections.namedtuple(
@ -57,15 +53,14 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
_SampleParams = collections.namedtuple( _SampleParams = collections.namedtuple(
'_SampleParams', ['sum_params', 'quantile_estimator_params']) '_SampleParams', ['sum_params', 'quantile_estimator_params'])
def __init__( def __init__(self,
self, initial_l2_norm_clip,
initial_l2_norm_clip, noise_multiplier,
noise_multiplier, target_unclipped_quantile,
target_unclipped_quantile, learning_rate,
learning_rate, clipped_count_stddev,
clipped_count_stddev, expected_num_records,
expected_num_records, geometric_update=True):
geometric_update=True):
"""Initializes the QuantileAdaptiveClipSumQuery. """Initializes the QuantileAdaptiveClipSumQuery.
Args: Args:
@ -75,9 +70,9 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
target_unclipped_quantile: The desired quantile of updates which should be target_unclipped_quantile: The desired quantile of updates which should be
unclipped. I.e., a value of 0.8 means a value of l2_norm_clip should be unclipped. I.e., a value of 0.8 means a value of l2_norm_clip should be
found for which approximately 20% of updates are clipped each round. found for which approximately 20% of updates are clipped each round.
learning_rate: The learning rate for the clipping norm adaptation. A learning_rate: The learning rate for the clipping norm adaptation. A rate
rate of r means that the clipping norm will change by a maximum of r at of r means that the clipping norm will change by a maximum of r at each
each step. This maximum is attained when |clip - target| is 1.0. step. This maximum is attained when |clip - target| is 1.0.
clipped_count_stddev: The stddev of the noise added to the clipped_count. clipped_count_stddev: The stddev of the noise added to the clipped_count.
Since the sensitivity of the clipped count is 0.5, as a rule of thumb it Since the sensitivity of the clipped count is 0.5, as a rule of thumb it
should be about 0.5 for reasonable privacy. should be about 0.5 for reasonable privacy.
@ -88,16 +83,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
self._noise_multiplier = noise_multiplier self._noise_multiplier = noise_multiplier
self._quantile_estimator_query = quantile_estimator_query.QuantileEstimatorQuery( self._quantile_estimator_query = quantile_estimator_query.QuantileEstimatorQuery(
initial_l2_norm_clip, initial_l2_norm_clip, target_unclipped_quantile, learning_rate,
target_unclipped_quantile, clipped_count_stddev, expected_num_records, geometric_update)
learning_rate,
clipped_count_stddev,
expected_num_records,
geometric_update)
self._sum_query = gaussian_query.GaussianSumQuery( self._sum_query = gaussian_query.GaussianSumQuery(
initial_l2_norm_clip, initial_l2_norm_clip, noise_multiplier * initial_l2_norm_clip)
noise_multiplier * initial_l2_norm_clip)
assert isinstance(self._sum_query, dp_query.SumAggregationDPQuery) assert isinstance(self._sum_query, dp_query.SumAggregationDPQuery)
assert isinstance(self._quantile_estimator_query, assert isinstance(self._quantile_estimator_query,
@ -146,70 +136,13 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
new_l2_norm_clip = tf.maximum(new_l2_norm_clip, 0.0) new_l2_norm_clip = tf.maximum(new_l2_norm_clip, 0.0)
new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
new_sum_query_state = self._sum_query.make_global_state( new_sum_query_state = self._sum_query.make_global_state(
l2_norm_clip=new_l2_norm_clip, l2_norm_clip=new_l2_norm_clip, stddev=new_sum_stddev)
stddev=new_sum_stddev)
new_global_state = self._GlobalState( new_global_state = self._GlobalState(global_state.noise_multiplier,
global_state.noise_multiplier, new_sum_query_state,
new_sum_query_state, new_quantile_estimator_state)
new_quantile_estimator_state)
return noised_vectors, new_global_state return noised_vectors, new_global_state
def derive_metrics(self, global_state): def derive_metrics(self, global_state):
return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip) return collections.OrderedDict(clip=global_state.sum_state.l2_norm_clip)
class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
"""DPQuery for average queries with adaptive clipping.
Clipping norm is tuned adaptively to converge to a value such that a specified
quantile of updates are clipped.
Note that we use "fixed-denominator" estimation: the denominator should be
specified as the expected number of records per sample. Accumulating the
denominator separately would also be possible but would produce a higher
variance estimator.
"""
def __init__(
self,
initial_l2_norm_clip,
noise_multiplier,
denominator,
target_unclipped_quantile,
learning_rate,
clipped_count_stddev,
expected_num_records,
geometric_update=False):
"""Initializes the AdaptiveClipAverageQuery.
Args:
initial_l2_norm_clip: The initial value of clipping norm.
noise_multiplier: The multiplier of the l2_norm_clip to make the stddev of
the noise.
denominator: The normalization constant (applied after noise is added to
the sum).
target_unclipped_quantile: The desired quantile of updates which should be
unclipped.
learning_rate: The learning rate for the clipping norm adaptation. A
rate of r means that the clipping norm will change by a maximum of r at
each step. The maximum is attained when |clip - target| is 1.0.
clipped_count_stddev: The stddev of the noise added to the clipped_count.
Since the sensitivity of the clipped count is 0.5, as a rule of thumb it
should be about 0.5 for reasonable privacy.
expected_num_records: The expected number of records, used to estimate the
clipped count quantile.
geometric_update: If True, use geometric updating of clip.
"""
numerator_query = QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip,
noise_multiplier,
target_unclipped_quantile,
learning_rate,
clipped_count_stddev,
expected_num_records,
geometric_update)
super(QuantileAdaptiveClipAverageQuery, self).__init__(
numerator_query=numerator_query,
denominator=denominator)

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for QuantileAdaptiveClipSumQuery.""" """Tests for QuantileAdaptiveClipSumQuery."""
from __future__ import absolute_import from __future__ import absolute_import
@ -30,8 +29,8 @@ from tensorflow_privacy.privacy.dp_query import test_utils
tf.enable_eager_execution() tf.enable_eager_execution()
class QuantileAdaptiveClipSumQueryTest( class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
tf.test.TestCase, parameterized.TestCase): parameterized.TestCase):
def test_sum_no_clip_no_noise(self): def test_sum_no_clip_no_noise(self):
record1 = tf.constant([2.0, 0.0]) record1 = tf.constant([2.0, 0.0])
@ -87,47 +86,6 @@ class QuantileAdaptiveClipSumQueryTest(
result_stddev = np.std(noised_sums) result_stddev = np.std(noised_sums)
self.assertNear(result_stddev, stddev, 0.1) self.assertNear(result_stddev, stddev, 0.1)
def test_average_no_noise(self):
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
record2 = tf.constant([-1.0, 2.0]) # Not clipped.
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
initial_l2_norm_clip=3.0,
noise_multiplier=0.0,
denominator=2.0,
target_unclipped_quantile=1.0,
learning_rate=0.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = query_result.numpy()
expected_average = [1.0, 1.0]
self.assertAllClose(result, expected_average)
def test_average_with_noise(self):
record1, record2 = 2.71828, 3.14159
sum_stddev = 1.0
denominator = 2.0
clip = 3.0
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
initial_l2_norm_clip=clip,
noise_multiplier=sum_stddev / clip,
denominator=denominator,
target_unclipped_quantile=1.0,
learning_rate=0.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
noised_averages = []
for _ in range(1000):
query_result, _ = test_utils.run_query(query, [record1, record2])
noised_averages.append(query_result.numpy())
result_stddev = np.std(noised_averages)
avg_stddev = sum_stddev / denominator
self.assertNear(result_stddev, avg_stddev, 0.1)
def test_adaptation_target_zero(self): def test_adaptation_target_zero(self):
record1 = tf.constant([8.5]) record1 = tf.constant([8.5])
record2 = tf.constant([-7.25]) record2 = tf.constant([-7.25])
@ -154,8 +112,8 @@ class QuantileAdaptiveClipSumQueryTest(
expected_sums = [1.25, 1.25, 0.75, 0.25, 0.0] expected_sums = [1.25, 1.25, 0.75, 0.25, 0.0]
expected_clips = [9.0, 8.0, 7.5, 7.0, 7.0] expected_clips = [9.0, 8.0, 7.5, 7.0, 7.0]
for expected_sum, expected_clip in zip(expected_sums, expected_clips): for expected_sum, expected_clip in zip(expected_sums, expected_clips):
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(query, [record1, record2],
query, [record1, record2], global_state) global_state)
actual_clip = global_state.sum_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
@ -170,7 +128,7 @@ class QuantileAdaptiveClipSumQueryTest(
initial_l2_norm_clip=16.0, initial_l2_norm_clip=16.0,
noise_multiplier=0.0, noise_multiplier=0.0,
target_unclipped_quantile=0.0, target_unclipped_quantile=0.0,
learning_rate=np.log(2.0), # Geometric steps in powers of 2. learning_rate=np.log(2.0), # Geometric steps in powers of 2.
clipped_count_stddev=0.0, clipped_count_stddev=0.0,
expected_num_records=2.0, expected_num_records=2.0,
geometric_update=True) geometric_update=True)
@ -185,13 +143,13 @@ class QuantileAdaptiveClipSumQueryTest(
# 4 / sqrt(2.0). Still only one record is clipped, so it reduces to 2.0. # 4 / sqrt(2.0). Still only one record is clipped, so it reduces to 2.0.
# Now both records are clipped, and the clip norm stays there (at 2.0). # Now both records are clipped, and the clip norm stays there (at 2.0).
four_div_root_two = 4 / np.sqrt(2.0) # approx 2.828 four_div_root_two = 4 / np.sqrt(2.0) # approx 2.828
expected_sums = [2.5, 2.5, 1.5, four_div_root_two - 2.5, 0.0] expected_sums = [2.5, 2.5, 1.5, four_div_root_two - 2.5, 0.0]
expected_clips = [8.0, 4.0, four_div_root_two, 2.0, 2.0] expected_clips = [8.0, 4.0, four_div_root_two, 2.0, 2.0]
for expected_sum, expected_clip in zip(expected_sums, expected_clips): for expected_sum, expected_clip in zip(expected_sums, expected_clips):
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(query, [record1, record2],
query, [record1, record2], global_state) global_state)
actual_clip = global_state.sum_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
@ -224,8 +182,8 @@ class QuantileAdaptiveClipSumQueryTest(
expected_sums = [0.0, 0.0, 0.5, 1.0, 1.25] expected_sums = [0.0, 0.0, 0.5, 1.0, 1.25]
expected_clips = [1.0, 2.0, 2.5, 3.0, 3.0] expected_clips = [1.0, 2.0, 2.5, 3.0, 3.0]
for expected_sum, expected_clip in zip(expected_sums, expected_clips): for expected_sum, expected_clip in zip(expected_sums, expected_clips):
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(query, [record1, record2],
query, [record1, record2], global_state) global_state)
actual_clip = global_state.sum_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
@ -240,7 +198,7 @@ class QuantileAdaptiveClipSumQueryTest(
initial_l2_norm_clip=0.5, initial_l2_norm_clip=0.5,
noise_multiplier=0.0, noise_multiplier=0.0,
target_unclipped_quantile=1.0, target_unclipped_quantile=1.0,
learning_rate=np.log(2.0), # Geometric steps in powers of 2. learning_rate=np.log(2.0), # Geometric steps in powers of 2.
clipped_count_stddev=0.0, clipped_count_stddev=0.0,
expected_num_records=2.0, expected_num_records=2.0,
geometric_update=True) geometric_update=True)
@ -255,30 +213,31 @@ class QuantileAdaptiveClipSumQueryTest(
# multiplied by sqrt(2.0). Still only one is clipped so it increases to 4.0. # multiplied by sqrt(2.0). Still only one is clipped so it increases to 4.0.
# Now neither record is clipped, and the clip norm stays there (at 4.0). # Now neither record is clipped, and the clip norm stays there (at 4.0).
two_times_root_two = 2 * np.sqrt(2.0) # approx 2.828 two_times_root_two = 2 * np.sqrt(2.0) # approx 2.828
expected_sums = [0.0, 0.0, 0.5, two_times_root_two - 1.5, 1.5] expected_sums = [0.0, 0.0, 0.5, two_times_root_two - 1.5, 1.5]
expected_clips = [1.0, 2.0, two_times_root_two, 4.0, 4.0] expected_clips = [1.0, 2.0, two_times_root_two, 4.0, 4.0]
for expected_sum, expected_clip in zip(expected_sums, expected_clips): for expected_sum, expected_clip in zip(expected_sums, expected_clips):
actual_sum, global_state = test_utils.run_query( actual_sum, global_state = test_utils.run_query(query, [record1, record2],
query, [record1, record2], global_state) global_state)
actual_clip = global_state.sum_state.l2_norm_clip actual_clip = global_state.sum_state.l2_norm_clip
self.assertAllClose(actual_clip.numpy(), expected_clip) self.assertAllClose(actual_clip.numpy(), expected_clip)
self.assertAllClose(actual_sum.numpy(), (expected_sum,)) self.assertAllClose(actual_sum.numpy(), (expected_sum,))
@parameterized.named_parameters( @parameterized.named_parameters(('start_low_arithmetic', True, False),
('start_low_arithmetic', True, False), ('start_low_geometric', True, True),
('start_low_geometric', True, True), ('start_high_arithmetic', False, False),
('start_high_arithmetic', False, False), ('start_high_geometric', False, True))
('start_high_geometric', False, True))
def test_adaptation_linspace(self, start_low, geometric): def test_adaptation_linspace(self, start_low, geometric):
# 21 records equally spaced from 0 to 10 in 0.5 increments. # 21 records equally spaced from 0 to 10 in 0.5 increments.
# Test that we converge to the correct median value and bounce around it. # Test that we converge to the correct median value and bounce around it.
num_records = 21 num_records = 21
records = [tf.constant(x) for x in np.linspace( records = [
0.0, 10.0, num=num_records, dtype=np.float32)] tf.constant(x)
for x in np.linspace(0.0, 10.0, num=num_records, dtype=np.float32)
]
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=(1.0 if start_low else 10.0), initial_l2_norm_clip=(1.0 if start_low else 10.0),
@ -299,11 +258,10 @@ class QuantileAdaptiveClipSumQueryTest(
if t > 40: if t > 40:
self.assertNear(actual_clip, 5.0, 0.25) self.assertNear(actual_clip, 5.0, 0.25)
@parameterized.named_parameters( @parameterized.named_parameters(('start_low_arithmetic', True, False),
('start_low_arithmetic', True, False), ('start_low_geometric', True, True),
('start_low_geometric', True, True), ('start_high_arithmetic', False, False),
('start_high_arithmetic', False, False), ('start_high_geometric', False, True))
('start_high_geometric', False, True))
def test_adaptation_all_equal(self, start_low, geometric): def test_adaptation_all_equal(self, start_low, geometric):
# 20 equal records. Test that we converge to that record and bounce around # 20 equal records. Test that we converge to that record and bounce around
# it. Unlike the linspace test, the quantile-matching objective is very # it. Unlike the linspace test, the quantile-matching objective is very
@ -349,8 +307,8 @@ class QuantileAdaptiveClipSumQueryTest(
expected_num_records=2.0, expected_num_records=2.0,
geometric_update=False) geometric_update=False)
query = privacy_ledger.QueryWithLedger( query = privacy_ledger.QueryWithLedger(query, population_size,
query, population_size, selection_probability) selection_probability)
# First sample. # First sample.
tf.assign(population_size, 10) tf.assign(population_size, 10)