Add the noise standard deviation (stddev) to the state of the Gaussian noise generator, so that it can be changed between steps. This will be used to enable an adaptive clipping norm in tree aggregation queries.

PiperOrigin-RevId: 393851743
This commit is contained in:
Zheng Xu 2021-08-30 14:17:04 -07:00 committed by A. Unique TensorFlower
parent 07c248d868
commit 5edea5863c
2 changed files with 70 additions and 27 deletions

View file

@ -21,8 +21,8 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
as helper functions for `tree_aggregation_query`. This module and helper as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible. functions are publicly accessible.
""" """
import abc import abc
import collections
from typing import Any, Callable, Collection, Optional, Tuple, Union from typing import Any, Callable, Collection, Optional, Tuple, Union
import attr import attr
@ -70,6 +70,9 @@ class GaussianNoiseGenerator(ValueGenerator):
nested structure of `tf.TensorSpec`s. nested structure of `tf.TensorSpec`s.
""" """
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple('_GlobalState', ['seeds', 'stddev'])
def __init__(self, def __init__(self,
noise_std: float, noise_std: float,
specs: Collection[tf.TensorSpec], specs: Collection[tf.TensorSpec],
@ -83,48 +86,57 @@ class GaussianNoiseGenerator(ValueGenerator):
seed: An optional integer seed. If None, generator is seeded from the seed: An optional integer seed. If None, generator is seeded from the
clock. clock.
""" """
self.noise_std = noise_std self._noise_std = noise_std
self.specs = specs self._specs = specs
self.seed = seed self._seed = seed
def initialize(self):
    """Builds the starting state of the GaussianNoiseGenerator.

    Returns:
      A named tuple of (seeds, stddev). `seeds` is an int64 tensor with two
      entries and `stddev` is a float32 tensor holding the noise standard
      deviation the generator was constructed with.
    """
    if self._seed is not None:
        # A fixed seed was supplied: fill both seed entries with it.
        seeds = tf.constant(self._seed, dtype=tf.int64, shape=(2,))
    else:
        # No seed supplied: derive both seed entries from the clock.
        time_now = tf.timestamp()
        residual = time_now - tf.math.floor(time_now)
        clock_micros = tf.math.floor(tf.timestamp() * 1e6)
        fraction_nanos = tf.math.floor(residual * 1e9)
        seeds = tf.cast(
            tf.stack([clock_micros, fraction_nanos]), dtype=tf.int64)
    stddev = tf.constant(self._noise_std, dtype=tf.float32)
    return self._GlobalState(seeds, stddev)
def next(self, state):
    """Gets next value and advances the GaussianNoiseGenerator.

    Args:
      state: The current state, a named tuple of (seeds, stddev).

    Returns:
      A tuple of (sample, new_state). `sample` has the same nested structure
      as the specs, with one noise tensor per spec drawn using
      `state.stddev`. `new_state` is (seeds + number of specs, stddev): each
      spec consumes one seed offset and the standard deviation is carried
      over unchanged.
    """
    flat_structure = tf.nest.flatten(self._specs)
    # Give every spec its own seed offset so no two tensors share a seed.
    flat_seeds = [state.seeds + i for i in range(len(flat_structure))]
    nest_seeds = tf.nest.pack_sequence_as(self._specs, flat_seeds)

    def _get_noise(spec, seed):
        return tf.random.stateless_normal(
            shape=spec.shape, seed=seed, stddev=state.stddev)

    nest_noise = tf.nest.map_structure(_get_noise, self._specs, nest_seeds)
    # Same value as `flat_seeds[-1] + 1` when there is at least one spec, but
    # does not raise IndexError when the spec structure is empty.
    next_seeds = state.seeds + len(flat_structure)
    return nest_noise, self._GlobalState(next_seeds, state.stddev)
def make_state(self, seeds: tf.Tensor, stddev: tf.Tensor):
    """Packs the given seeds and stddev into a generator state.

    Args:
      seeds: A tensor that must have shape `(2,)`; it is cast to int64.
      stddev: The noise standard deviation; it is cast to float32.

    Returns:
      A named tuple of (seeds, stddev).
    """
    checked_seeds = tf.ensure_shape(seeds, shape=(2,))
    int_seeds = tf.cast(checked_seeds, dtype=tf.int64)
    float_stddev = tf.cast(stddev, dtype=tf.float32)
    return self._GlobalState(int_seeds, float_stddev)
class StatelessValueGenerator(ValueGenerator): class StatelessValueGenerator(ValueGenerator):

View file

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Tests for `tree_aggregation`.""" """Tests for `tree_aggregation`."""
import math import math
import random
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
@ -297,7 +298,11 @@ class EfficientTreeAggregatorTest(tf.test.TestCase, parameterized.TestCase):
tf.nest.map_structure(self.assertAllClose, val, expected_result) tf.nest.map_structure(self.assertAllClose, val, expected_result)
class GaussianNoiseGeneratorTest(tf.test.TestCase): class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
def assertStateEqual(self, state1, state2):
    """Asserts that two generator states hold equal values leaf by leaf.

    Args:
      state1: A (possibly nested) state structure.
      state2: The state structure to compare against `state1`.
    """
    flat1 = tf.nest.flatten(state1)
    flat2 = tf.nest.flatten(state2)
    # `zip` stops at the shorter input, so compare the leaf counts first;
    # otherwise states with a different number of leaves could pass.
    self.assertEqual(len(flat1), len(flat2))
    for s1, s2 in zip(flat1, flat2):
        self.assertAllEqual(s1, s2)
def test_random_generator_tf(self, def test_random_generator_tf(self,
noise_mean=1.0, noise_mean=1.0,
@ -330,12 +335,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator( g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed) noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed)
gstate2 = g.initialize() gstate2 = g.initialize()
self.assertAllEqual(gstate, gstate2) self.assertStateEqual(gstate, gstate2)
for _ in range(steps): for _ in range(steps):
value, gstate = g.next(gstate) value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2) value2, gstate2 = g2.next(gstate2)
self.assertAllEqual(value, value2) self.assertAllEqual(value, value2)
self.assertAllEqual(gstate, gstate2) self.assertStateEqual(gstate, gstate2)
def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1): def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1):
g = tree_aggregation.GaussianNoiseGenerator( g = tree_aggregation.GaussianNoiseGenerator(
@ -344,11 +349,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator( g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=tf.TensorSpec([])) noise_std=noise_std, specs=tf.TensorSpec([]))
gstate2 = g2.initialize() gstate2 = g2.initialize()
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
for _ in range(steps): for _ in range(steps):
value, gstate = g.next(gstate) value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2) value2, gstate2 = g2.next(gstate2)
self.assertNotAllEqual(value, value2) self.assertNotAllEqual(value, value2)
self.assertNotAllEqual(gstate, gstate2) self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1): def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1):
specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])] specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])]
@ -358,11 +364,36 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator( g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=specs, seed=seed) noise_std=noise_std, specs=specs, seed=seed)
gstate2 = g2.initialize() gstate2 = g2.initialize()
self.assertStateEqual(gstate, gstate2)
for _ in range(steps): for _ in range(steps):
value, gstate = g.next(gstate) value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2) value2, gstate2 = g2.next(gstate2)
self.assertAllClose(value, value2) self.assertAllClose(value, value2)
self.assertAllEqual(gstate, gstate2) self.assertStateEqual(gstate, gstate2)
@parameterized.named_parameters(
    ('increase', range(10), 1),
    ('decrease', range(30, 20, -2), 2),
    ('flat', [3.0] * 5, 1),
    ('small', [0.1**x for x in range(4)], 4),
    # Each entry comes from its own seeded generator so that the "random"
    # case uses the same stddev values on every run.
    ('random', [random.Random(i).uniform(1, 10) for i in range(5)], 4),
)
def test_adaptive_stddev(self, stddev_list, reset_frequency):
    """Checks that noise follows the stddev stored in the generator state.

    Args:
      stddev_list: The stddev values to put into the state, in order.
      reset_frequency: How many samples to draw after each stddev update.
    """
    # The stddev estimation follows a chi distribution. The confidence for
    # `sample_num` samples should be high, and we use a relatively large
    # tolerance to guard the numerical stability for small stddev values.
    sample_num, tolerance = 10000, 0.05
    g = tree_aggregation.GaussianNoiseGenerator(
        noise_std=1., specs=tf.TensorSpec([sample_num]), seed=2021)
    gstate = g.initialize()
    for stddev in stddev_list:
        # Keep the current seeds and swap in the new stddev only.
        gstate = g.make_state(
            gstate.seeds, tf.constant(stddev, dtype=tf.float32))
        for _ in range(reset_frequency):
            prev_gstate = gstate
            value, gstate = g.next(gstate)
            self.assertAllClose(
                tf.math.reduce_std(value), stddev, rtol=tolerance)
            # Every draw must advance the seeds.
            self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase): class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):