forked from 626_privacy/tensorflow_privacy
Add the noise standard deviation (stddev) to the state of the Gaussian noise generator, so that an adaptive clipping norm can be enabled in tree aggregation queries.
PiperOrigin-RevId: 393851743
This commit is contained in:
parent
07c248d868
commit
5edea5863c
2 changed files with 70 additions and 27 deletions
|
@ -21,8 +21,8 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
|
||||||
as helper functions for `tree_aggregation_query`. This module and helper
|
as helper functions for `tree_aggregation_query`. This module and helper
|
||||||
functions are publicly accessible.
|
functions are publicly accessible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import collections
|
||||||
from typing import Any, Callable, Collection, Optional, Tuple, Union
|
from typing import Any, Callable, Collection, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -70,6 +70,9 @@ class GaussianNoiseGenerator(ValueGenerator):
|
||||||
nested structure of `tf.TensorSpec`s.
|
nested structure of `tf.TensorSpec`s.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
_GlobalState = collections.namedtuple('_GlobalState', ['seeds', 'stddev'])
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
noise_std: float,
|
noise_std: float,
|
||||||
specs: Collection[tf.TensorSpec],
|
specs: Collection[tf.TensorSpec],
|
||||||
|
@ -83,48 +86,57 @@ class GaussianNoiseGenerator(ValueGenerator):
|
||||||
seed: An optional integer seed. If None, generator is seeded from the
|
seed: An optional integer seed. If None, generator is seeded from the
|
||||||
clock.
|
clock.
|
||||||
"""
|
"""
|
||||||
self.noise_std = noise_std
|
self._noise_std = noise_std
|
||||||
self.specs = specs
|
self._specs = specs
|
||||||
self.seed = seed
|
self._seed = seed
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""Makes an initial state for the GaussianNoiseGenerator.
|
"""Makes an initial state for the GaussianNoiseGenerator.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An initial state.
|
A named tuple of (seeds, stddev).
|
||||||
"""
|
"""
|
||||||
if self.seed is None:
|
if self._seed is None:
|
||||||
time_now = tf.timestamp()
|
time_now = tf.timestamp()
|
||||||
residual = time_now - tf.math.floor(time_now)
|
residual = time_now - tf.math.floor(time_now)
|
||||||
return tf.cast(
|
return self._GlobalState(
|
||||||
tf.stack([
|
tf.cast(
|
||||||
tf.math.floor(tf.timestamp() * 1e6),
|
tf.stack([
|
||||||
tf.math.floor(residual * 1e9)
|
tf.math.floor(tf.timestamp() * 1e6),
|
||||||
]),
|
tf.math.floor(residual * 1e9)
|
||||||
dtype=tf.int64)
|
]),
|
||||||
|
dtype=tf.int64), tf.constant(self._noise_std, dtype=tf.float32))
|
||||||
else:
|
else:
|
||||||
return tf.constant(self.seed, dtype=tf.int64, shape=(2,))
|
return self._GlobalState(
|
||||||
|
tf.constant(self._seed, dtype=tf.int64, shape=(2,)),
|
||||||
|
tf.constant(self._noise_std, dtype=tf.float32))
|
||||||
|
|
||||||
def next(self, state):
|
def next(self, state):
|
||||||
"""Gets next value and advances the GaussianNoiseGenerator.
|
"""Gets next value and advances the GaussianNoiseGenerator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: The current state.
|
state: The current state, a named tuple of (seeds, stddev).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A pair (sample, new_state) where sample is a new sample and new_state
|
A tuple of (sample, new_state) where sample is a new sample and new_state
|
||||||
is the advanced state.
|
is the advanced state (seeds advanced by one per flattened spec, stddev unchanged).
|
||||||
"""
|
"""
|
||||||
flat_structure = tf.nest.flatten(self.specs)
|
flat_structure = tf.nest.flatten(self._specs)
|
||||||
flat_seeds = [state + i for i in range(len(flat_structure))]
|
flat_seeds = [state.seeds + i for i in range(len(flat_structure))]
|
||||||
nest_seeds = tf.nest.pack_sequence_as(self.specs, flat_seeds)
|
nest_seeds = tf.nest.pack_sequence_as(self._specs, flat_seeds)
|
||||||
|
|
||||||
def _get_noise(spec, seed):
|
def _get_noise(spec, seed):
|
||||||
return tf.random.stateless_normal(
|
return tf.random.stateless_normal(
|
||||||
shape=spec.shape, seed=seed, stddev=self.noise_std)
|
shape=spec.shape, seed=seed, stddev=state.stddev)
|
||||||
|
|
||||||
nest_noise = tf.nest.map_structure(_get_noise, self.specs, nest_seeds)
|
nest_noise = tf.nest.map_structure(_get_noise, self._specs, nest_seeds)
|
||||||
return nest_noise, flat_seeds[-1] + 1
|
return nest_noise, self._GlobalState(flat_seeds[-1] + 1, state.stddev)
|
||||||
|
|
||||||
|
def make_state(self, seeds: tf.Tensor, stddev: tf.Tensor):
|
||||||
|
"""Returns a new named tuple of (seeds, stddev)."""
|
||||||
|
seeds = tf.ensure_shape(seeds, shape=(2,))
|
||||||
|
return self._GlobalState(
|
||||||
|
tf.cast(seeds, dtype=tf.int64), tf.cast(stddev, dtype=tf.float32))
|
||||||
|
|
||||||
|
|
||||||
class StatelessValueGenerator(ValueGenerator):
|
class StatelessValueGenerator(ValueGenerator):
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for `tree_aggregation`."""
|
"""Tests for `tree_aggregation`."""
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -297,7 +298,11 @@ class EfficientTreeAggregatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
tf.nest.map_structure(self.assertAllClose, val, expected_result)
|
tf.nest.map_structure(self.assertAllClose, val, expected_result)
|
||||||
|
|
||||||
|
|
||||||
class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def assertStateEqual(self, state1, state2):
|
||||||
|
for s1, s2 in zip(tf.nest.flatten(state1), tf.nest.flatten(state2)):
|
||||||
|
self.assertAllEqual(s1, s2)
|
||||||
|
|
||||||
def test_random_generator_tf(self,
|
def test_random_generator_tf(self,
|
||||||
noise_mean=1.0,
|
noise_mean=1.0,
|
||||||
|
@ -330,12 +335,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
||||||
g2 = tree_aggregation.GaussianNoiseGenerator(
|
g2 = tree_aggregation.GaussianNoiseGenerator(
|
||||||
noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed)
|
noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed)
|
||||||
gstate2 = g.initialize()
|
gstate2 = g.initialize()
|
||||||
self.assertAllEqual(gstate, gstate2)
|
self.assertStateEqual(gstate, gstate2)
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
value, gstate = g.next(gstate)
|
value, gstate = g.next(gstate)
|
||||||
value2, gstate2 = g2.next(gstate2)
|
value2, gstate2 = g2.next(gstate2)
|
||||||
self.assertAllEqual(value, value2)
|
self.assertAllEqual(value, value2)
|
||||||
self.assertAllEqual(gstate, gstate2)
|
self.assertStateEqual(gstate, gstate2)
|
||||||
|
|
||||||
def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1):
|
def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1):
|
||||||
g = tree_aggregation.GaussianNoiseGenerator(
|
g = tree_aggregation.GaussianNoiseGenerator(
|
||||||
|
@ -344,11 +349,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
||||||
g2 = tree_aggregation.GaussianNoiseGenerator(
|
g2 = tree_aggregation.GaussianNoiseGenerator(
|
||||||
noise_std=noise_std, specs=tf.TensorSpec([]))
|
noise_std=noise_std, specs=tf.TensorSpec([]))
|
||||||
gstate2 = g2.initialize()
|
gstate2 = g2.initialize()
|
||||||
|
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
value, gstate = g.next(gstate)
|
value, gstate = g.next(gstate)
|
||||||
value2, gstate2 = g2.next(gstate2)
|
value2, gstate2 = g2.next(gstate2)
|
||||||
self.assertNotAllEqual(value, value2)
|
self.assertNotAllEqual(value, value2)
|
||||||
self.assertNotAllEqual(gstate, gstate2)
|
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
|
||||||
|
|
||||||
def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1):
|
def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1):
|
||||||
specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])]
|
specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])]
|
||||||
|
@ -358,11 +364,36 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
||||||
g2 = tree_aggregation.GaussianNoiseGenerator(
|
g2 = tree_aggregation.GaussianNoiseGenerator(
|
||||||
noise_std=noise_std, specs=specs, seed=seed)
|
noise_std=noise_std, specs=specs, seed=seed)
|
||||||
gstate2 = g2.initialize()
|
gstate2 = g2.initialize()
|
||||||
|
self.assertStateEqual(gstate, gstate2)
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
value, gstate = g.next(gstate)
|
value, gstate = g.next(gstate)
|
||||||
value2, gstate2 = g2.next(gstate2)
|
value2, gstate2 = g2.next(gstate2)
|
||||||
self.assertAllClose(value, value2)
|
self.assertAllClose(value, value2)
|
||||||
self.assertAllEqual(gstate, gstate2)
|
self.assertStateEqual(gstate, gstate2)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('increase', range(10), 1),
|
||||||
|
('decrease', range(30, 20, -2), 2),
|
||||||
|
('flat', [3.0] * 5, 1),
|
||||||
|
('small', [0.1**x for x in range(4)], 4),
|
||||||
|
('random', [random.uniform(1, 10) for _ in range(5)], 4),
|
||||||
|
)
|
||||||
|
def test_adaptive_stddev(self, stddev_list, reset_frequency):
|
||||||
|
# The stddev estimation follows a chi distribution. The confidence for
|
||||||
|
# `sample_num` samples should be high, and we use a relatively large
|
||||||
|
# tolerance to guard the numerical stability for small stddev values.
|
||||||
|
sample_num, tolerance = 10000, 0.05
|
||||||
|
g = tree_aggregation.GaussianNoiseGenerator(
|
||||||
|
noise_std=1., specs=tf.TensorSpec([sample_num]), seed=2021)
|
||||||
|
gstate = g.initialize()
|
||||||
|
for stddev in stddev_list:
|
||||||
|
gstate = g.make_state(gstate.seeds, tf.constant(stddev, dtype=tf.float32))
|
||||||
|
for _ in range(reset_frequency):
|
||||||
|
prev_gstate = gstate
|
||||||
|
value, gstate = g.next(gstate)
|
||||||
|
print(tf.math.reduce_std(value), stddev)
|
||||||
|
self.assertAllClose(tf.math.reduce_std(value), stddev, rtol=tolerance)
|
||||||
|
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
|
||||||
|
|
||||||
|
|
||||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
Loading…
Reference in a new issue