forked from 626_privacy/tensorflow_privacy
Add STDDEV to the state of random noise generator, which will be used to enable adaptive clipping norm in tree aggregation queries.
PiperOrigin-RevId: 393851743
This commit is contained in:
parent
07c248d868
commit
5edea5863c
2 changed files with 70 additions and 27 deletions
|
@ -21,8 +21,8 @@ module implements the core logic of tree aggregation in Tensorflow, which serves
|
|||
as helper functions for `tree_aggregation_query`. This module and helper
|
||||
functions are publicly accessible.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
from typing import Any, Callable, Collection, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
|
@ -70,6 +70,9 @@ class GaussianNoiseGenerator(ValueGenerator):
|
|||
nested structure of `tf.TensorSpec`s.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_GlobalState = collections.namedtuple('_GlobalState', ['seeds', 'stddev'])
|
||||
|
||||
def __init__(self,
|
||||
noise_std: float,
|
||||
specs: Collection[tf.TensorSpec],
|
||||
|
@ -83,48 +86,57 @@ class GaussianNoiseGenerator(ValueGenerator):
|
|||
seed: An optional integer seed. If None, generator is seeded from the
|
||||
clock.
|
||||
"""
|
||||
self.noise_std = noise_std
|
||||
self.specs = specs
|
||||
self.seed = seed
|
||||
self._noise_std = noise_std
|
||||
self._specs = specs
|
||||
self._seed = seed
|
||||
|
||||
def initialize(self):
|
||||
"""Makes an initial state for the GaussianNoiseGenerator.
|
||||
|
||||
Returns:
|
||||
An initial state.
|
||||
A named tuple of (seeds, stddev).
|
||||
"""
|
||||
if self.seed is None:
|
||||
if self._seed is None:
|
||||
time_now = tf.timestamp()
|
||||
residual = time_now - tf.math.floor(time_now)
|
||||
return tf.cast(
|
||||
return self._GlobalState(
|
||||
tf.cast(
|
||||
tf.stack([
|
||||
tf.math.floor(tf.timestamp() * 1e6),
|
||||
tf.math.floor(residual * 1e9)
|
||||
]),
|
||||
dtype=tf.int64)
|
||||
dtype=tf.int64), tf.constant(self._noise_std, dtype=tf.float32))
|
||||
else:
|
||||
return tf.constant(self.seed, dtype=tf.int64, shape=(2,))
|
||||
return self._GlobalState(
|
||||
tf.constant(self._seed, dtype=tf.int64, shape=(2,)),
|
||||
tf.constant(self._noise_std, dtype=tf.float32))
|
||||
|
||||
def next(self, state):
|
||||
"""Gets next value and advances the GaussianNoiseGenerator.
|
||||
|
||||
Args:
|
||||
state: The current state.
|
||||
state: The current state (seed, noise_std).
|
||||
|
||||
Returns:
|
||||
A pair (sample, new_state) where sample is a new sample and new_state
|
||||
is the advanced state.
|
||||
A tuple of (sample, new_state) where sample is a new sample and new_state
|
||||
is the advanced state (seed+1, noise_std).
|
||||
"""
|
||||
flat_structure = tf.nest.flatten(self.specs)
|
||||
flat_seeds = [state + i for i in range(len(flat_structure))]
|
||||
nest_seeds = tf.nest.pack_sequence_as(self.specs, flat_seeds)
|
||||
flat_structure = tf.nest.flatten(self._specs)
|
||||
flat_seeds = [state.seeds + i for i in range(len(flat_structure))]
|
||||
nest_seeds = tf.nest.pack_sequence_as(self._specs, flat_seeds)
|
||||
|
||||
def _get_noise(spec, seed):
|
||||
return tf.random.stateless_normal(
|
||||
shape=spec.shape, seed=seed, stddev=self.noise_std)
|
||||
shape=spec.shape, seed=seed, stddev=state.stddev)
|
||||
|
||||
nest_noise = tf.nest.map_structure(_get_noise, self.specs, nest_seeds)
|
||||
return nest_noise, flat_seeds[-1] + 1
|
||||
nest_noise = tf.nest.map_structure(_get_noise, self._specs, nest_seeds)
|
||||
return nest_noise, self._GlobalState(flat_seeds[-1] + 1, state.stddev)
|
||||
|
||||
def make_state(self, seeds: tf.Tensor, stddev: tf.Tensor):
|
||||
"""Returns a new named tuple of (seeds, stddev)."""
|
||||
seeds = tf.ensure_shape(seeds, shape=(2,))
|
||||
return self._GlobalState(
|
||||
tf.cast(seeds, dtype=tf.int64), tf.cast(stddev, dtype=tf.float32))
|
||||
|
||||
|
||||
class StatelessValueGenerator(ValueGenerator):
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Tests for `tree_aggregation`."""
|
||||
import math
|
||||
import random
|
||||
from absl.testing import parameterized
|
||||
|
||||
import tensorflow as tf
|
||||
|
@ -297,7 +298,11 @@ class EfficientTreeAggregatorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
tf.nest.map_structure(self.assertAllClose, val, expected_result)
|
||||
|
||||
|
||||
class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
||||
class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def assertStateEqual(self, state1, state2):
|
||||
for s1, s2 in zip(tf.nest.flatten(state1), tf.nest.flatten(state2)):
|
||||
self.assertAllEqual(s1, s2)
|
||||
|
||||
def test_random_generator_tf(self,
|
||||
noise_mean=1.0,
|
||||
|
@ -330,12 +335,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
|||
g2 = tree_aggregation.GaussianNoiseGenerator(
|
||||
noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed)
|
||||
gstate2 = g.initialize()
|
||||
self.assertAllEqual(gstate, gstate2)
|
||||
self.assertStateEqual(gstate, gstate2)
|
||||
for _ in range(steps):
|
||||
value, gstate = g.next(gstate)
|
||||
value2, gstate2 = g2.next(gstate2)
|
||||
self.assertAllEqual(value, value2)
|
||||
self.assertAllEqual(gstate, gstate2)
|
||||
self.assertStateEqual(gstate, gstate2)
|
||||
|
||||
def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1):
|
||||
g = tree_aggregation.GaussianNoiseGenerator(
|
||||
|
@ -344,11 +349,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
|||
g2 = tree_aggregation.GaussianNoiseGenerator(
|
||||
noise_std=noise_std, specs=tf.TensorSpec([]))
|
||||
gstate2 = g2.initialize()
|
||||
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
|
||||
for _ in range(steps):
|
||||
value, gstate = g.next(gstate)
|
||||
value2, gstate2 = g2.next(gstate2)
|
||||
self.assertNotAllEqual(value, value2)
|
||||
self.assertNotAllEqual(gstate, gstate2)
|
||||
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
|
||||
|
||||
def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1):
|
||||
specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])]
|
||||
|
@ -358,11 +364,36 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
|
|||
g2 = tree_aggregation.GaussianNoiseGenerator(
|
||||
noise_std=noise_std, specs=specs, seed=seed)
|
||||
gstate2 = g2.initialize()
|
||||
self.assertStateEqual(gstate, gstate2)
|
||||
for _ in range(steps):
|
||||
value, gstate = g.next(gstate)
|
||||
value2, gstate2 = g2.next(gstate2)
|
||||
self.assertAllClose(value, value2)
|
||||
self.assertAllEqual(gstate, gstate2)
|
||||
self.assertStateEqual(gstate, gstate2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('increase', range(10), 1),
|
||||
('decrease', range(30, 20, -2), 2),
|
||||
('flat', [3.0] * 5, 1),
|
||||
('small', [0.1**x for x in range(4)], 4),
|
||||
('random', [random.uniform(1, 10) for _ in range(5)], 4),
|
||||
)
|
||||
def test_adaptive_stddev(self, stddev_list, reset_frequency):
|
||||
# The stddev estimation follows a chi distribution. The confidence for
|
||||
# `sample_num` samples should be high, and we use a relatively large
|
||||
# tolerance to guard the numerical stability for small stddev values.
|
||||
sample_num, tolerance = 10000, 0.05
|
||||
g = tree_aggregation.GaussianNoiseGenerator(
|
||||
noise_std=1., specs=tf.TensorSpec([sample_num]), seed=2021)
|
||||
gstate = g.initialize()
|
||||
for stddev in stddev_list:
|
||||
gstate = g.make_state(gstate.seeds, tf.constant(stddev, dtype=tf.float32))
|
||||
for _ in range(reset_frequency):
|
||||
prev_gstate = gstate
|
||||
value, gstate = g.next(gstate)
|
||||
print(tf.math.reduce_std(value), stddev)
|
||||
self.assertAllClose(tf.math.reduce_std(value), stddev, rtol=tolerance)
|
||||
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
|
||||
|
||||
|
||||
class RestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
|
Loading…
Reference in a new issue