Adaptive clipping in DP-FTRL with restart.
PiperOrigin-RevId: 513934548
This commit is contained in:
parent
8bfafdd74d
commit
0a0f377f3f
5 changed files with 694 additions and 0 deletions
|
@ -26,6 +26,7 @@ py_library(
|
||||||
"//tensorflow_privacy/privacy/dp_query:no_privacy_query",
|
"//tensorflow_privacy/privacy/dp_query:no_privacy_query",
|
||||||
"//tensorflow_privacy/privacy/dp_query:normalized_query",
|
"//tensorflow_privacy/privacy/dp_query:normalized_query",
|
||||||
"//tensorflow_privacy/privacy/dp_query:quantile_adaptive_clip_sum_query",
|
"//tensorflow_privacy/privacy/dp_query:quantile_adaptive_clip_sum_query",
|
||||||
|
"//tensorflow_privacy/privacy/dp_query:quantile_adaptive_clip_tree_query",
|
||||||
"//tensorflow_privacy/privacy/dp_query:quantile_estimator_query",
|
"//tensorflow_privacy/privacy/dp_query:quantile_estimator_query",
|
||||||
"//tensorflow_privacy/privacy/dp_query:restart_query",
|
"//tensorflow_privacy/privacy/dp_query:restart_query",
|
||||||
"//tensorflow_privacy/privacy/dp_query:tree_aggregation",
|
"//tensorflow_privacy/privacy/dp_query:tree_aggregation",
|
||||||
|
|
|
@ -45,6 +45,7 @@ else:
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
|
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
|
||||||
|
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_tree_query import QAdaClipTreeResSumQuery
|
||||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation
|
||||||
from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
|
from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
|
||||||
|
|
|
@ -323,3 +323,30 @@ py_library(
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [":dp_query"],
|
deps = [":dp_query"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library target for the tree aggregation query with quantile-based adaptive
# clipping (sources in quantile_adaptive_clip_tree_query.py).
py_library(
    name = "quantile_adaptive_clip_tree_query",
    srcs = ["quantile_adaptive_clip_tree_query.py"],
    srcs_version = "PY3",
    deps = [
        ":dp_query",
        ":quantile_estimator_query",
        ":tree_aggregation_query",
        "//third_party/py/tensorflow:tensorflow_no_contrib",
        "@com_google_differential_py//dp_accounting:accounting",
    ],
)
|
||||||
|
|
||||||
|
# Test target for the adaptive-clipping tree query; it is declared with a long
# timeout and split into five shards.
py_test(
    name = "quantile_adaptive_clip_tree_query_test",
    timeout = "long",
    srcs = ["quantile_adaptive_clip_tree_query_test.py"],
    python_version = "PY3",
    shard_count = 5,
    srcs_version = "PY3",
    deps = [
        ":quantile_adaptive_clip_tree_query",
        ":test_utils",
        "//third_party/py/tensorflow:tensorflow_no_contrib",
    ],
)
|
||||||
|
|
|
@ -0,0 +1,196 @@
|
||||||
|
# Copyright 2021, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""`DPQuery` for tree aggregation queries with adaptive clipping."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import dp_accounting
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||||
|
|
||||||
|
|
||||||
|
class QAdaClipTreeResSumQuery(dp_query.SumAggregationDPQuery):
  """Tree aggregation sum `DPQuery` whose clip norm adapts to a quantile.

  The noise for the cumulative sum follows the tree aggregation scheme from
  "Practical and Private (Deep) Learning without Sampling or Shuffling"
  (https://arxiv.org/abs/2103.00039). The clip norm is driven by the
  quantile-based estimator from "Differentially Private Learning with Adaptive
  Clipping" (https://arxiv.org/abs/1905.03871).

  The quantile is estimated continuously, whereas the clip norm used by the sum
  query only changes when `reset_state` is called, which also resets the tree
  state. As a result the clip norm (and the corresponding noise stddev) stays
  constant within a single tree.
  """

  # The namedtuple attributes below are deliberately named like classes.
  # pylint: disable=invalid-name
  _GlobalState = collections.namedtuple(
      '_GlobalState',
      ['noise_multiplier', 'sum_state', 'quantile_estimator_state'])

  _SampleState = collections.namedtuple(
      '_SampleState', ['sum_state', 'quantile_estimator_state'])

  _SampleParams = collections.namedtuple(
      '_SampleParams', ['sum_params', 'quantile_estimator_params'])

  def __init__(self,
               initial_l2_norm_clip,
               noise_multiplier,
               record_specs,
               target_unclipped_quantile,
               learning_rate,
               clipped_count_stddev,
               expected_num_records,
               geometric_update=True,
               noise_seed=None):
    """Initializes the `QAdaClipTreeResSumQuery`.

    Args:
      initial_l2_norm_clip: Starting value of the clipping norm.
      noise_multiplier: Ratio between the stddev of the noise added to the
        output and the current clipping norm.
      record_specs: A nested structure of `tf.TensorSpec`s describing the
        structure and shapes of the records.
      target_unclipped_quantile: Desired quantile of updates that should remain
        unclipped. For example, 0.8 asks for an l2_norm_clip such that roughly
        20% of updates are clipped each round. Andrew et al. recommend 0.5,
        i.e. clipping to the median.
      learning_rate: Learning rate of the clipping norm adaptation. Under
        geometric updating, a rate of r lets the clipping norm change by at most
        a factor of exp(r) per round; that maximum is reached when
        |actual_unclipped_fraction - target_unclipped_quantile| is 1.0. Andrew
        et al. recommend 0.2 for geometric updating.
      clipped_count_stddev: Stddev of the noise added to the clipped_count.
        Andrew et al. recommend `expected_num_records / 20` for reasonably fast
        adaptation and high privacy.
      expected_num_records: Expected number of records per round, used to
        estimate the clipped count quantile.
      geometric_update: If `True`, update the clip geometrically (recommended).
      noise_seed: Integer seed for the Gaussian noise generator of
        `TreeResidualSumQuery`. If `None`, a nondeterministic seed based on
        system time will be generated.
    """
    self._noise_multiplier = noise_multiplier

    # Estimator that tracks the target quantile of the record norms.
    estimator_cls = quantile_estimator_query.TreeQuantileEstimatorQuery
    self._quantile_estimator_query = estimator_cls(
        initial_l2_norm_clip,
        target_unclipped_quantile,
        learning_rate,
        clipped_count_stddev,
        expected_num_records,
        geometric_update,
    )

    # Tree aggregation sum query, started from the initial clip norm.
    build_sum_query = (
        tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query)
    self._sum_query = build_sum_query(
        initial_l2_norm_clip,
        noise_multiplier,
        record_specs,
        noise_seed=noise_seed,
        use_efficient=True,
    )

    for sub_query in (self._sum_query, self._quantile_estimator_query):
      assert isinstance(sub_query, dp_query.SumAggregationDPQuery)

  def initial_global_state(self):
    """Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
    noise_multiplier = tf.cast(self._noise_multiplier, tf.float32)
    sum_state = self._sum_query.initial_global_state()
    estimator_state = self._quantile_estimator_query.initial_global_state()
    return self._GlobalState(
        noise_multiplier=noise_multiplier,
        sum_state=sum_state,
        quantile_estimator_state=estimator_state)

  def derive_sample_params(self, global_state):
    """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
    sum_params = self._sum_query.derive_sample_params(global_state.sum_state)
    estimator_params = self._quantile_estimator_query.derive_sample_params(
        global_state.quantile_estimator_state)
    return self._SampleParams(
        sum_params=sum_params, quantile_estimator_params=estimator_params)

  def initial_sample_state(self, template):
    """Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
    sum_sample_state = self._sum_query.initial_sample_state(template)
    estimator_sample_state = (
        self._quantile_estimator_query.initial_sample_state())
    return self._SampleState(
        sum_state=sum_sample_state,
        quantile_estimator_state=estimator_sample_state)

  def preprocess_record(self, params, record):
    """Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
    # The sum query hands back the clipped record together with the global
    # norm, which is then passed on to the quantile estimator.
    preprocessed = self._sum_query.preprocess_record_l2_impl(
        params.sum_params, record)
    clipped_record, global_norm = preprocessed

    below_estimate = self._quantile_estimator_query.preprocess_record(
        params.quantile_estimator_params, global_norm)

    return self._SampleState(
        sum_state=clipped_record, quantile_estimator_state=below_estimate)

  def get_noised_result(self, sample_state, global_state):
    """Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
    sum_output = self._sum_query.get_noised_result(
        sample_state.sum_state, global_state.sum_state)
    noised_vectors, sum_state, sum_event = sum_output

    # Only the updated state and the privacy event of the estimator are used;
    # its noised value is discarded.
    estimator_output = self._quantile_estimator_query.get_noised_result(
        sample_state.quantile_estimator_state,
        global_state.quantile_estimator_state)
    _, quantile_estimator_state, quantile_event = estimator_output

    new_global_state = self._GlobalState(
        noise_multiplier=global_state.noise_multiplier,
        sum_state=sum_state,
        quantile_estimator_state=quantile_estimator_state)
    event = dp_accounting.ComposedDpEvent(events=[sum_event, quantile_event])
    return noised_vectors, new_global_state, event

  def reset_state(self, noised_results, global_state):
    """Returns state after resetting the tree and updating the clip norm.

    Intended for `restart_query.RestartQuery`, which calls this after
    `get_noised_result` once the restarting condition is met. The clip norm
    (and the corresponding noise stddev) of the tree aggregated sum query is
    refreshed from the quantile-based estimate only here.

    Args:
      noised_results: Noised cumulative sum returned by `get_noised_result`.
      global_state: Updated global state returned by `get_noised_result`, which
        records noise for the conceptual cumulative sum of the current leaf
        node, and tree state for the next conceptual cumulative sum.

    Returns:
      New global state with restarted tree state, and new clip norm.
    """
    estimator_state = global_state.quantile_estimator_state

    # The new clip norm is the current estimate, floored at zero.
    new_l2_norm_clip = tf.math.maximum(estimator_state.current_estimate, 0.0)
    new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier

    # Install the new clip/stddev first, then restart the tree.
    sum_state = self._sum_query.reset_l2_clip_gaussian_noise(
        global_state.sum_state,
        clip_norm=new_l2_norm_clip,
        stddev=new_sum_stddev)
    sum_state = self._sum_query.reset_state(noised_results, sum_state)

    quantile_estimator_state = self._quantile_estimator_query.reset_state(
        noised_results, estimator_state)

    return global_state._replace(
        sum_state=sum_state, quantile_estimator_state=quantile_estimator_state)

  def derive_metrics(self, global_state):
    """Returns the clipping norm and estimated quantile value as a metric."""
    current_clip = global_state.sum_state.clip_value
    estimate_clip = global_state.quantile_estimator_state.current_estimate
    return collections.OrderedDict([
        ('current_clip', current_clip),
        ('estimate_clip', estimate_clip),
    ])
|
|
@ -0,0 +1,469 @@
|
||||||
|
# Copyright 2021, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.dp_query import quantile_adaptive_clip_tree_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class QAdaClipTreeResSumQueryTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `QAdaClipTreeResSumQuery`."""

  def test_sum_no_clip_no_noise(self):
    # Both record norms are below the clip of 10.0, so nothing is clipped.
    records = [tf.constant([2.0, 0.0]), tf.constant([-1.0, 1.0])]

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=10.0,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([2]),
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)
    query_result, _ = test_utils.run_query(query, records)

    self.assertAllClose(query_result.numpy(), [1.0, 1.0])

  def test_sum_with_clip_no_noise(self):
    records = [
        tf.constant([-6.0, 8.0]),  # Clipped to [-3.0, 4.0].
        tf.constant([4.0, -3.0]),  # Not clipped.
    ]

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=5.0,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([2]),
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0)
    query_result, _ = test_utils.run_query(query, records)

    self.assertAllClose(query_result.numpy(), [1.0, 1.0])

  def test_sum_with_noise(self):
    vector_size = 1000
    records = [
        tf.constant(2.71828, shape=[vector_size]),
        tf.constant(3.14159, shape=[vector_size]),
    ]
    stddev = 1.0
    clip = 5.0

    # The noise multiplier is picked so that clip * multiplier equals `stddev`.
    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=clip,
        noise_multiplier=stddev / clip,
        record_specs=tf.TensorSpec([vector_size]),
        target_unclipped_quantile=1.0,
        learning_rate=0.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0,
        noise_seed=1)
    query_result, _ = test_utils.run_query(query, records)

    self.assertNear(np.std(query_result.numpy()), stddev, 0.1)

  def _test_estimate_clip_expected_sum(self,
                                       query,
                                       global_state,
                                       records,
                                       expected_sums,
                                       expected_clips,
                                       reset=True):
    """Runs the query repeatedly and checks the sums and clip norms.

    With `reset=True` the state is reset after every round, so the applied
    clip norm must follow `expected_clips` and the sum `expected_sums`. With
    `reset=False` the applied clip norm must stay at its value from before the
    round, only the quantile estimate follows `expected_clips`, and every sum
    must equal `expected_sums[0]`.
    """
    for expected_sum, expected_clip in zip(expected_sums, expected_clips):
      clip_before_round = global_state.sum_state.clip_value
      actual_sum, global_state = test_utils.run_query(query, records,
                                                      global_state)
      if not reset:
        applied_clip = global_state.sum_state.clip_value
        estimated_clip = global_state.quantile_estimator_state.current_estimate
        self.assertAllClose(applied_clip.numpy(), clip_before_round)
        self.assertAllClose(estimated_clip.numpy(), expected_clip)
        self.assertAllClose(actual_sum.numpy(), (expected_sums[0],))
        continue

      global_state = query.reset_state(actual_sum, global_state)
      applied_clip = global_state.sum_state.clip_value
      self.assertAllClose(applied_clip.numpy(), expected_clip)
      self.assertAllClose(actual_sum.numpy(), (expected_sum,))

  @parameterized.named_parameters(('adaptive', True), ('constant', False))
  def test_adaptation_target_zero(self, reset):
    records = [tf.constant([8.5]), tf.constant([-7.25])]

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=10.0,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=0.0,
        learning_rate=1.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0,
        geometric_update=False)

    global_state = query.initial_global_state()
    self.assertAllClose(global_state.sum_state.clip_value, 10.0)

    # In the first two rounds nothing is clipped, so the clip drops by 1.0 (the
    # learning rate) each time. Once the clip reaches 8.0 one record is
    # clipped, so the clip drops by only 0.5 per round. Two rounds later both
    # records are clipped and the clip norm settles at 7.0.
    expected_sums = [1.25, 1.25, 0.75, 0.25, 0.0]
    expected_clips = [9.0, 8.0, 7.5, 7.0, 7.0]
    self._test_estimate_clip_expected_sum(
        query,
        global_state,
        records,
        expected_sums,
        expected_clips,
        reset=reset)

  @parameterized.named_parameters(('adaptive', True), ('constant', False))
  def test_adaptation_target_zero_geometric(self, reset):
    records = [tf.constant([5.0]), tf.constant([-2.5])]

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=16.0,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=0.0,
        learning_rate=np.log(2.0),  # Geometric steps in powers of 2.
        clipped_count_stddev=0.0,
        expected_num_records=2.0,
        geometric_update=True)

    global_state = query.initial_global_state()
    self.assertAllClose(global_state.sum_state.clip_value, 16.0)

    # For two rounds nothing is clipped, so the clip is halved each time. Then
    # one record is clipped, so the clip shrinks only by a factor of sqrt(2.0)
    # to 4 / sqrt(2.0). Still only one record is clipped, so it shrinks to 2.0.
    # At that point both records are clipped and the clip norm stays at 2.0.
    four_div_root_two = 4 / np.sqrt(2.0)  # approx 2.828

    expected_sums = [2.5, 2.5, 1.5, four_div_root_two - 2.5, 0.0]
    expected_clips = [8.0, 4.0, four_div_root_two, 2.0, 2.0]
    self._test_estimate_clip_expected_sum(
        query,
        global_state,
        records,
        expected_sums,
        expected_clips,
        reset=reset)

  @parameterized.named_parameters(('adaptive', True), ('constant', False))
  def test_adaptation_target_one(self, reset):
    records = [tf.constant([-1.5]), tf.constant([2.75])]

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=0.0,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=1.0,
        learning_rate=1.0,
        clipped_count_stddev=0.0,
        expected_num_records=2.0,
        geometric_update=False)

    global_state = query.initial_global_state()
    self.assertAllClose(global_state.sum_state.clip_value, 0.0)

    # In the first two rounds both records are clipped, so the clip rises by
    # 1.0 (the learning rate) each time. Once the clip reaches 2.0 only one
    # record is clipped, so the clip rises by only 0.5 per round. Two rounds
    # later neither record is clipped and the clip norm settles at 3.0.
    expected_sums = [0.0, 0.0, 0.5, 1.0, 1.25]
    expected_clips = [1.0, 2.0, 2.5, 3.0, 3.0]
    self._test_estimate_clip_expected_sum(
        query,
        global_state,
        records,
        expected_sums,
        expected_clips,
        reset=reset)

  @parameterized.named_parameters(('adaptive', True), ('constant', False))
  def test_adaptation_target_one_geometric(self, reset):
    records = [tf.constant([-1.5]), tf.constant([3.0])]

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=0.5,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=1.0,
        learning_rate=np.log(2.0),  # Geometric steps in powers of 2.
        clipped_count_stddev=0.0,
        expected_num_records=2.0,
        geometric_update=True)

    global_state = query.initial_global_state()
    self.assertAllClose(global_state.sum_state.clip_value, 0.5)

    # In the first two rounds both records are clipped, so the clip doubles
    # each time. Once the clip reaches 2.0 only one record is clipped, so the
    # clip grows by a factor of sqrt(2.0). Still only one record is clipped, so
    # it grows to 4.0. At that point neither record is clipped and the clip
    # norm stays at 4.0.
    two_times_root_two = 2 * np.sqrt(2.0)  # approx 2.828

    expected_sums = [0.0, 0.0, 0.5, two_times_root_two - 1.5, 1.5]
    expected_clips = [1.0, 2.0, two_times_root_two, 4.0, 4.0]
    self._test_estimate_clip_expected_sum(
        query,
        global_state,
        records,
        expected_sums,
        expected_clips,
        reset=reset)

  def _test_estimate_clip_converge(self,
                                   query,
                                   records,
                                   expected_clip,
                                   tolerance,
                                   learning_rate=None,
                                   total_steps=50,
                                   converge_steps=40):
    """Checks that the clip norm is near `expected_clip` after a warm-up.

    The query is run for `total_steps` rounds. For every round index above
    `converge_steps` the state is reset and the resulting clip norm must lie
    within `tolerance` of `expected_clip`. If `learning_rate` is given, it is
    assigned 1 / sqrt(round index + 1) before each round.
    """
    global_state = query.initial_global_state()
    for step in range(total_steps):
      if learning_rate is not None:
        learning_rate.assign(1.0 / np.sqrt(step + 1))
      actual_sum, global_state = test_utils.run_query(query, records,
                                                      global_state)
      if step <= converge_steps:
        continue
      global_state = query.reset_state(actual_sum, global_state)
      estimate_clip = global_state.sum_state.clip_value
      self.assertNear(estimate_clip, expected_clip, tolerance)

  @parameterized.named_parameters(('start_low_arithmetic', True, False),
                                  ('start_low_geometric', True, True),
                                  ('start_high_arithmetic', False, False),
                                  ('start_high_geometric', False, True))
  def test_adaptation_linspace(self, start_low, geometric):
    # `num_records` records evenly spaced from 0 to 10. Test that we converge
    # to the correct median value and bounce around it.
    num_records = 21
    records = [
        tf.constant(x)
        for x in np.linspace(0.0, 10.0, num=num_records, dtype=np.float32)
    ]
    initial_clip = 1.0 if start_low else 10.0

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=initial_clip,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=0.5,
        learning_rate=1.0,
        clipped_count_stddev=0.0,
        expected_num_records=num_records,
        geometric_update=geometric)

    self._test_estimate_clip_converge(
        query, records, expected_clip=5., tolerance=0.25)

  @parameterized.named_parameters(('start_low_arithmetic', True, False),
                                  ('start_low_geometric', True, True),
                                  ('start_high_arithmetic', False, False),
                                  ('start_high_geometric', False, True))
  def test_adaptation_all_equal(self, start_low, geometric):
    # `num_records` equal records. Test that we converge to that record and
    # bounce around it. Unlike the linspace test, the quantile-matching
    # objective is very sharp at the optimum so a decaying learning rate is
    # necessary.
    num_records = 20
    records = [tf.constant(5.0)] * num_records

    learning_rate = tf.Variable(1.0)
    initial_clip = 1.0 if start_low else 10.0

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=initial_clip,
        noise_multiplier=0.0,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=0.5,
        learning_rate=learning_rate,
        clipped_count_stddev=0.0,
        expected_num_records=num_records,
        geometric_update=geometric)

    self._test_estimate_clip_converge(
        query,
        records,
        expected_clip=5.,
        tolerance=0.5,
        learning_rate=learning_rate)

  def _test_noise_multiplier(self,
                             query,
                             records,
                             noise_multiplier,
                             learning_rate=None,
                             tolerance=0.15,
                             total_steps=10):
    """Checks that the noise stddev equals clip norm times `noise_multiplier`.

    In each round the expected stddev is compared both with the stddev stored
    in the tree state's value generator state and, within relative tolerance
    `tolerance`, with the empirical stddev of the noised result. If
    `learning_rate` is given, it is assigned (round index + 1) ** -0.5 before
    each round.
    """
    global_state = query.initial_global_state()
    for step in range(total_steps):
      if learning_rate is not None:
        learning_rate.assign((step + 1.)**(-.5))

      # Drive the query by hand so that the state can be inspected between
      # `get_noised_result` and `reset_state`.
      params = query.derive_sample_params(global_state)
      sample_state = query.initial_sample_state(records[0])
      for record in records:
        sample_state = query.accumulate_record(params, sample_state, record)
      actual_sum, global_state, _ = query.get_noised_result(
          sample_state, global_state)

      sum_state = global_state.sum_state
      expected_std = sum_state.clip_value * noise_multiplier
      self.assertAllClose(
          expected_std, sum_state.tree_state.value_generator_state.stddev)

      global_state = query.reset_state(actual_sum, global_state)
      self.assertAllClose(
          expected_std, tf.math.reduce_std(actual_sum), rtol=tolerance)

  @parameterized.named_parameters(('start_low_arithmetic', True, False),
                                  ('start_low_geometric', True, True),
                                  ('start_high_arithmetic', False, False),
                                  ('start_high_geometric', False, True))
  def test_adaptation_linspace_noise(self, start_low, geometric):
    # `num_records` constant vectors whose scale factors are evenly spaced from
    # 0 to 10. Test that the noise stddev tracks the clip norm while the clip
    # norm adapts.
    num_records, vector_size, noise_multiplier = 11, 1000, 0.1
    records = [
        tf.constant(
            vector_size**(-.5) * x, shape=[vector_size], dtype=tf.float32)
        for x in np.linspace(0.0, 10.0, num=num_records)
    ]
    initial_clip = 1.0 if start_low else 10.0

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=initial_clip,
        noise_multiplier=noise_multiplier,
        record_specs=tf.TensorSpec([vector_size]),
        target_unclipped_quantile=0.5,
        learning_rate=1.0,
        clipped_count_stddev=0.0,
        expected_num_records=num_records,
        geometric_update=geometric,
        noise_seed=1)

    self._test_noise_multiplier(query, records, noise_multiplier)

  @parameterized.named_parameters(('start_low_arithmetic', True, False),
                                  ('start_low_geometric', True, True),
                                  ('start_high_arithmetic', False, False),
                                  ('start_high_geometric', False, True))
  def test_adaptation_all_equal_noise(self, start_low, geometric):
    # `num_records` equal records. Unlike the linspace test, the
    # quantile-matching objective is very sharp at the optimum so a decaying
    # learning rate is used. Test that the noise stddev tracks the clip norm
    # while the clip norm adapts.
    num_records, vector_size, noise_multiplier = 10, 1000, 0.5
    records = [
        tf.constant(
            vector_size**(-.5) * 5., shape=[vector_size], dtype=tf.float32)
    ] * num_records

    learning_rate = tf.Variable(1.)
    initial_clip = 1.0 if start_low else 10.0

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=initial_clip,
        noise_multiplier=noise_multiplier,
        record_specs=tf.TensorSpec([vector_size]),
        target_unclipped_quantile=0.5,
        learning_rate=learning_rate,
        clipped_count_stddev=0.0,
        expected_num_records=num_records,
        geometric_update=geometric,
        noise_seed=1)

    self._test_noise_multiplier(
        query, records, noise_multiplier, learning_rate=learning_rate)

  def test_adaptation_clip_noise(self):
    sample_num, tolerance, stddev = 1000, 0.3, 0.1
    initial_clip, expected_num_records = 5., 2.
    records = [tf.constant(1.), tf.constant(10.)]

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=initial_clip,
        noise_multiplier=0.,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=.5,
        learning_rate=1.,
        clipped_count_stddev=stddev,
        expected_num_records=expected_num_records,
        geometric_update=False,
        noise_seed=1)

    global_state = query.initial_global_state()
    samples = []
    for _ in range(sample_num):
      noised_results, global_state = test_utils.run_query(
          query, records, global_state)
      samples.append(noised_results.numpy())
      global_state = query.reset_state(noised_results, global_state)
      self.assertNotEqual(
          global_state.quantile_estimator_state.current_estimate, initial_clip)
      # Put the estimate back to the initial clip so that every round starts
      # its noise estimation from the same clip norm.
      restored_estimator_state = (
          global_state.quantile_estimator_state._replace(
              current_estimate=initial_clip))
      global_state = global_state._replace(
          quantile_estimator_state=restored_estimator_state)

    # The sum result is 1. (unclipped) + 5. (clipped) = 6.
    self.assertAllClose(np.mean(samples), 6., atol=4 * stddev)
    self.assertAllClose(
        np.std(samples), stddev / expected_num_records, rtol=tolerance)

  @parameterized.named_parameters(('start_low_arithmetic', True, False),
                                  ('start_low_geometric', True, True),
                                  ('start_high_arithmetic', False, False),
                                  ('start_high_geometric', False, True))
  def test_adaptation_linspace_noise_converge(self, start_low, geometric):
    # `num_records` records evenly spaced from 0 to 10. Test that we converge
    # to the correct median value and bounce around it, this time with a small
    # amount of noise on both the sum and the clipped count.
    num_records = 21
    records = [
        tf.constant(x)
        for x in np.linspace(0.0, 10.0, num=num_records, dtype=np.float32)
    ]
    initial_clip = 1.0 if start_low else 10.0

    query = quantile_adaptive_clip_tree_query.QAdaClipTreeResSumQuery(
        initial_l2_norm_clip=initial_clip,
        noise_multiplier=0.01,
        record_specs=tf.TensorSpec([]),
        target_unclipped_quantile=0.5,
        learning_rate=1.0,
        clipped_count_stddev=0.01,
        expected_num_records=num_records,
        geometric_update=geometric,
        noise_seed=1)

    self._test_estimate_clip_converge(
        query, records, expected_clip=5., tolerance=0.25)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
|
Loading…
Reference in a new issue