Fix num_records in quantile_adaptive_clip_sum_query_test.
PiperOrigin-RevId: 292995170
This commit is contained in:
parent
9bb3c1e6d8
commit
945075a136
1 changed files with 35 additions and 34 deletions
|
@ -273,10 +273,41 @@ class QuantileAdaptiveClipSumQueryTest(
|
|||
('start_high_geometric', False, True))
|
||||
def test_adaptation_linspace(self, start_low, geometric):
|
||||
# 100 records equally spaced from 0 to 10 in 0.1 increments.
|
||||
# Test that with a decaying learning rate we converge to the correct
|
||||
# median value and bounce around it.
|
||||
# Test that we converge to the correct median value and bounce around it.
|
||||
num_records = 21
|
||||
records = [tf.constant(x) for x in np.linspace(
|
||||
0.0, 10.0, num=21, dtype=np.float32)]
|
||||
0.0, 10.0, num=num_records, dtype=np.float32)]
|
||||
|
||||
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
|
||||
initial_l2_norm_clip=(1.0 if start_low else 10.0),
|
||||
noise_multiplier=0.0,
|
||||
target_unclipped_quantile=0.5,
|
||||
learning_rate=1.0,
|
||||
clipped_count_stddev=0.0,
|
||||
expected_num_records=num_records,
|
||||
geometric_update=geometric)
|
||||
|
||||
global_state = query.initial_global_state()
|
||||
|
||||
for t in range(50):
|
||||
_, global_state = test_utils.run_query(query, records, global_state)
|
||||
|
||||
actual_clip = global_state.sum_state.l2_norm_clip
|
||||
|
||||
if t > 40:
|
||||
self.assertNear(actual_clip, 5.0, 0.25)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('start_low_arithmetic', True, False),
|
||||
('start_low_geometric', True, True),
|
||||
('start_high_arithmetic', False, False),
|
||||
('start_high_geometric', False, True))
|
||||
def test_adaptation_all_equal(self, start_low, geometric):
|
||||
# 20 equal records. Test that we converge to that record and bounce around
|
||||
# it. Unlike the linspace test, the quantile-matching objective is very
|
||||
# sharp at the optimum so a decaying learning rate is necessary.
|
||||
num_records = 20
|
||||
records = [tf.constant(5.0)] * num_records
|
||||
|
||||
learning_rate = tf.Variable(1.0)
|
||||
|
||||
|
@ -286,37 +317,7 @@ class QuantileAdaptiveClipSumQueryTest(
|
|||
target_unclipped_quantile=0.5,
|
||||
learning_rate=learning_rate,
|
||||
clipped_count_stddev=0.0,
|
||||
expected_num_records=2.0,
|
||||
geometric_update=geometric)
|
||||
|
||||
global_state = query.initial_global_state()
|
||||
|
||||
for t in range(50):
|
||||
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
|
||||
_, global_state = test_utils.run_query(query, records, global_state)
|
||||
|
||||
actual_clip = global_state.sum_state.l2_norm_clip
|
||||
|
||||
if t > 40:
|
||||
self.assertNear(actual_clip, 5.0, 0.25)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('arithmetic', False),
|
||||
('geometric', True))
|
||||
def test_adaptation_all_equal(self, geometric):
|
||||
# 20 equal records. Test that with a decaying learning rate we converge to
|
||||
# that record and bounce around it.
|
||||
records = [tf.constant(5.0)] * 20
|
||||
|
||||
learning_rate = tf.Variable(1.0)
|
||||
|
||||
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
|
||||
initial_l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
target_unclipped_quantile=0.5,
|
||||
learning_rate=learning_rate,
|
||||
clipped_count_stddev=0.0,
|
||||
expected_num_records=2.0,
|
||||
expected_num_records=num_records,
|
||||
geometric_update=geometric)
|
||||
|
||||
global_state = query.initial_global_state()
|
||||
|
|
Loading…
Reference in a new issue