Add `stddev` to the state of the Gaussian noise generator, so that the noise standard deviation can be changed between steps; this will be used to enable an adaptive clipping norm in tree aggregation queries.

PiperOrigin-RevId: 393851743
This commit is contained in:
Zheng Xu 2021-08-30 14:17:04 -07:00 committed by A. Unique TensorFlower
parent 07c248d868
commit 5edea5863c
2 changed files with 70 additions and 27 deletions

View file

@ -21,8 +21,8 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible.
"""
import abc
import collections
from typing import Any, Callable, Collection, Optional, Tuple, Union
import attr
@ -70,6 +70,9 @@ class GaussianNoiseGenerator(ValueGenerator):
nested structure of `tf.TensorSpec`s.
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple('_GlobalState', ['seeds', 'stddev'])
def __init__(self,
noise_std: float,
specs: Collection[tf.TensorSpec],
@ -83,48 +86,57 @@ class GaussianNoiseGenerator(ValueGenerator):
seed: An optional integer seed. If None, generator is seeded from the
clock.
"""
self.noise_std = noise_std
self.specs = specs
self.seed = seed
self._noise_std = noise_std
self._specs = specs
self._seed = seed
def initialize(self):
"""Makes an initial state for the GaussianNoiseGenerator.
Returns:
An initial state.
A named tuple of (seeds, stddev).
"""
if self.seed is None:
if self._seed is None:
time_now = tf.timestamp()
residual = time_now - tf.math.floor(time_now)
return tf.cast(
tf.stack([
tf.math.floor(tf.timestamp() * 1e6),
tf.math.floor(residual * 1e9)
]),
dtype=tf.int64)
return self._GlobalState(
tf.cast(
tf.stack([
tf.math.floor(tf.timestamp() * 1e6),
tf.math.floor(residual * 1e9)
]),
dtype=tf.int64), tf.constant(self._noise_std, dtype=tf.float32))
else:
return tf.constant(self.seed, dtype=tf.int64, shape=(2,))
return self._GlobalState(
tf.constant(self._seed, dtype=tf.int64, shape=(2,)),
tf.constant(self._noise_std, dtype=tf.float32))
def next(self, state):
"""Gets next value and advances the GaussianNoiseGenerator.
Args:
state: The current state.
state: The current state (seed, noise_std).
Returns:
A pair (sample, new_state) where sample is a new sample and new_state
is the advanced state.
A tuple of (sample, new_state) where sample is a new sample and new_state
is the advanced state (seed+1, noise_std).
"""
flat_structure = tf.nest.flatten(self.specs)
flat_seeds = [state + i for i in range(len(flat_structure))]
nest_seeds = tf.nest.pack_sequence_as(self.specs, flat_seeds)
flat_structure = tf.nest.flatten(self._specs)
flat_seeds = [state.seeds + i for i in range(len(flat_structure))]
nest_seeds = tf.nest.pack_sequence_as(self._specs, flat_seeds)
def _get_noise(spec, seed):
return tf.random.stateless_normal(
shape=spec.shape, seed=seed, stddev=self.noise_std)
shape=spec.shape, seed=seed, stddev=state.stddev)
nest_noise = tf.nest.map_structure(_get_noise, self.specs, nest_seeds)
return nest_noise, flat_seeds[-1] + 1
nest_noise = tf.nest.map_structure(_get_noise, self._specs, nest_seeds)
return nest_noise, self._GlobalState(flat_seeds[-1] + 1, state.stddev)
def make_state(self, seeds: tf.Tensor, stddev: tf.Tensor):
  """Builds a generator state from explicit seeds and a noise stddev.

  Args:
    seeds: Tensor checked to have shape (2,); stored as `tf.int64`.
    stddev: Noise standard deviation; stored as `tf.float32`.

  Returns:
    A `_GlobalState` named tuple of (seeds, stddev).
  """
  # Validate the shape before casting, mirroring the order of the checks.
  checked_seeds = tf.ensure_shape(seeds, shape=(2,))
  int_seeds = tf.cast(checked_seeds, dtype=tf.int64)
  float_stddev = tf.cast(stddev, dtype=tf.float32)
  return self._GlobalState(seeds=int_seeds, stddev=float_stddev)
class StatelessValueGenerator(ValueGenerator):

View file

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for `tree_aggregation`."""
import math
import random
from absl.testing import parameterized
import tensorflow as tf
@ -297,7 +298,11 @@ class EfficientTreeAggregatorTest(tf.test.TestCase, parameterized.TestCase):
tf.nest.map_structure(self.assertAllClose, val, expected_result)
class GaussianNoiseGeneratorTest(tf.test.TestCase):
class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
def assertStateEqual(self, state1, state2):
  """Asserts that two (possibly nested) generator states are equal.

  Both states are flattened with `tf.nest.flatten` and compared leaf by leaf.

  Args:
    state1: The first state structure.
    state2: The second state structure.
  """
  flat1 = tf.nest.flatten(state1)
  flat2 = tf.nest.flatten(state2)
  # `zip` stops at the shorter input, so check the leaf counts explicitly;
  # otherwise a state with missing leaves would be reported as equal.
  self.assertEqual(len(flat1), len(flat2))
  for s1, s2 in zip(flat1, flat2):
    self.assertAllEqual(s1, s2)
def test_random_generator_tf(self,
noise_mean=1.0,
@ -330,12 +335,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed)
gstate2 = g.initialize()
self.assertAllEqual(gstate, gstate2)
self.assertStateEqual(gstate, gstate2)
for _ in range(steps):
value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2)
self.assertAllEqual(value, value2)
self.assertAllEqual(gstate, gstate2)
self.assertStateEqual(gstate, gstate2)
def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1):
g = tree_aggregation.GaussianNoiseGenerator(
@ -344,11 +349,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=tf.TensorSpec([]))
gstate2 = g2.initialize()
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
for _ in range(steps):
value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2)
self.assertNotAllEqual(value, value2)
self.assertNotAllEqual(gstate, gstate2)
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1):
specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])]
@ -358,11 +364,36 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=specs, seed=seed)
gstate2 = g2.initialize()
self.assertStateEqual(gstate, gstate2)
for _ in range(steps):
value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2)
self.assertAllClose(value, value2)
self.assertAllEqual(gstate, gstate2)
self.assertStateEqual(gstate, gstate2)
@parameterized.named_parameters(
    ('increase', range(10), 1),
    ('decrease', range(30, 20, -2), 2),
    ('flat', [3.0] * 5, 1),
    ('small', [0.1**x for x in range(4)], 4),
    # NOTE: these values are drawn when the module is imported, so they
    # differ from run to run.
    ('random', [random.uniform(1, 10) for _ in range(5)], 4),
)
def test_adaptive_stddev(self, stddev_list, reset_frequency):
  """Checks that generated noise follows the stddev stored in the state.

  Args:
    stddev_list: Stddev values that are written into the state in turn.
    reset_frequency: Number of `next` calls made for each stddev value.
  """
  # The stddev estimation follows a chi distribution. The confidence for
  # `sample_num` samples should be high, and we use a relatively large
  # tolerance to guard the numerical stability for small stddev values.
  sample_num, tolerance = 10000, 0.05
  g = tree_aggregation.GaussianNoiseGenerator(
      noise_std=1., specs=tf.TensorSpec([sample_num]), seed=2021)
  gstate = g.initialize()
  for stddev in stddev_list:
    # Keep the current seeds and replace only the stddev in the state.
    gstate = g.make_state(gstate.seeds, tf.constant(stddev, dtype=tf.float32))
    for _ in range(reset_frequency):
      prev_gstate = gstate
      value, gstate = g.next(gstate)
      self.assertAllClose(tf.math.reduce_std(value), stddev, rtol=tolerance)
      # Every call to `next` must advance the seeds.
      self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):