Adds DP-FTRL via tree aggregation optimizer DPFTRLTreeAggregationOptimizer
.
Includes renaming of `frequency` parameter in restart_query.py to `period` to more more accurately reflect its purpose. PiperOrigin-RevId: 480736961
This commit is contained in:
parent
5e37c1bc70
commit
71837fbeec
6 changed files with 539 additions and 139 deletions
|
@ -65,6 +65,7 @@ else:
|
|||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPSGDOptimizer
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPFTRLTreeAggregationOptimizer
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_gaussian_query_optimizer_class
|
||||
|
|
|
@ -61,25 +61,24 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
|||
The indicator will maintain an internal counter as state.
|
||||
"""
|
||||
|
||||
def __init__(self, frequency: int, warmup: Optional[int] = None):
|
||||
def __init__(self, period: int, warmup: Optional[int] = None):
|
||||
"""Construct the `PeriodicRoundRestartIndicator`.
|
||||
|
||||
Args:
|
||||
frequency: The `next` function will return `True` every `frequency` number
|
||||
of `next` calls.
|
||||
period: The `next` function will return `True` every `period` number of
|
||||
`next` calls.
|
||||
warmup: The first `True` will be returned at the `warmup` times call of
|
||||
`next`.
|
||||
"""
|
||||
if frequency < 1:
|
||||
raise ValueError('Restart frequency should be equal or larger than 1, '
|
||||
f'got {frequency}')
|
||||
if period < 1:
|
||||
raise ValueError('Restart period should be equal or larger than 1, '
|
||||
f'got {period}')
|
||||
if warmup is None:
|
||||
warmup = 0
|
||||
elif warmup <= 0 or warmup >= frequency:
|
||||
raise ValueError(
|
||||
f'Warmup should be between 1 and `frequency-1={frequency-1}`, '
|
||||
f'got {warmup}')
|
||||
self.frequency = frequency
|
||||
elif warmup <= 0 or warmup >= period:
|
||||
raise ValueError(f'Warmup must be between 1 and `period`-1={period-1}, '
|
||||
f'got {warmup}')
|
||||
self.period = period
|
||||
self.warmup = warmup
|
||||
|
||||
def initialize(self):
|
||||
|
@ -96,10 +95,10 @@ class PeriodicRoundRestartIndicator(RestartIndicator):
|
|||
A pair (value, new_state) where value is the bool indicator and new_state
|
||||
of `state+1`.
|
||||
"""
|
||||
frequency = tf.constant(self.frequency, tf.int32)
|
||||
period = tf.constant(self.period, tf.int32)
|
||||
warmup = tf.constant(self.warmup, tf.int32)
|
||||
state = state + tf.constant(1, tf.int32)
|
||||
flag = tf.math.equal(tf.math.floormod(state, frequency), warmup)
|
||||
flag = tf.math.equal(tf.math.floormod(state, period), warmup)
|
||||
return flag, state
|
||||
|
||||
|
||||
|
|
|
@ -23,44 +23,48 @@ from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
|||
|
||||
class RoundRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def assertRestartsOnPeriod(self, indicator: restart_query.RestartIndicator,
|
||||
state: tf.Tensor, total_steps: int, period: int,
|
||||
offset: int):
|
||||
"""Asserts a restart occurs only every `period` steps."""
|
||||
for step in range(total_steps):
|
||||
flag, state = indicator.next(state)
|
||||
if step % period == offset - 1:
|
||||
self.assertTrue(flag)
|
||||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1))
|
||||
def test_round_raise(self, frequency):
|
||||
def test_round_raise(self, period):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Restart frequency should be equal or larger than 1'):
|
||||
restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
ValueError, 'Restart period should be equal or larger than 1'):
|
||||
restart_query.PeriodicRoundRestartIndicator(period)
|
||||
|
||||
@parameterized.named_parameters(('zero', 0), ('negative', -1), ('equal', 2),
|
||||
('large', 3))
|
||||
def test_round_raise_warmup(self, warmup):
|
||||
frequency = 2
|
||||
period = 2
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f'Warmup should be between 1 and `frequency-1={frequency-1}`'):
|
||||
restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
|
||||
ValueError, f'Warmup must be between 1 and `period`-1={period-1}'):
|
||||
restart_query.PeriodicRoundRestartIndicator(period, warmup)
|
||||
|
||||
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
|
||||
def test_round_indicator(self, frequency):
|
||||
@parameterized.named_parameters(('period_1', 1), ('period_2', 2),
|
||||
('period_4', 4), ('period_5', 5))
|
||||
def test_round_indicator(self, period):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(period)
|
||||
state = indicator.initialize()
|
||||
for i in range(total_steps):
|
||||
flag, state = indicator.next(state)
|
||||
if i % frequency == frequency - 1:
|
||||
self.assertTrue(flag)
|
||||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
@parameterized.named_parameters(('f2', 2, 1), ('f4', 4, 3), ('f5', 5, 2))
|
||||
def test_round_indicator_warmup(self, frequency, warmup):
|
||||
self.assertRestartsOnPeriod(indicator, state, total_steps, period, period)
|
||||
|
||||
@parameterized.named_parameters(('period_2', 2, 1), ('period_4', 4, 3),
|
||||
('period_5', 5, 2))
|
||||
def test_round_indicator_warmup(self, period, warmup):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(period, warmup)
|
||||
state = indicator.initialize()
|
||||
for i in range(total_steps):
|
||||
flag, state = indicator.next(state)
|
||||
if i % frequency == warmup - 1:
|
||||
self.assertTrue(flag)
|
||||
else:
|
||||
self.assertFalse(flag)
|
||||
|
||||
self.assertRestartsOnPeriod(indicator, state, total_steps, period, warmup)
|
||||
|
||||
|
||||
class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
@ -116,9 +120,9 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('s1t5f6', 1., 5., 6),
|
||||
)
|
||||
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
|
||||
tree_node_value, frequency):
|
||||
tree_node_value, period):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(period)
|
||||
query = tree_aggregation_query.TreeCumulativeSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
|
@ -138,8 +142,8 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
# be inferred from the binary representation of the current step.
|
||||
expected = (
|
||||
scalar_value * (i + 1) +
|
||||
i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
|
||||
tree_node_value * bin(i % frequency + 1)[2:].count('1'))
|
||||
i // period * tree_node_value * bin(period)[2:].count('1') +
|
||||
tree_node_value * bin(i % period + 1)[2:].count('1'))
|
||||
self.assertEqual(query_result, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
@ -151,9 +155,9 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('s1t5f6', 1., 5., 6),
|
||||
)
|
||||
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
|
||||
frequency):
|
||||
period):
|
||||
total_steps = 20
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
|
||||
indicator = restart_query.PeriodicRoundRestartIndicator(period)
|
||||
query = tree_aggregation_query.TreeResidualSumQuery(
|
||||
clip_fn=_get_l2_clip_fn(),
|
||||
clip_value=scalar_value + 1., # no clip
|
||||
|
@ -172,8 +176,7 @@ class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
# two continous tree aggregation values. The tree aggregation value can
|
||||
# be inferred from the binary representation of the current step.
|
||||
expected = scalar_value + tree_node_value * (
|
||||
bin(i % frequency + 1)[2:].count('1') -
|
||||
bin(i % frequency)[2:].count('1'))
|
||||
bin(i % period + 1)[2:].count('1') - bin(i % period)[2:].count('1'))
|
||||
self.assertEqual(query_result, expected)
|
||||
|
||||
|
||||
|
|
|
@ -27,6 +27,8 @@ py_library(
|
|||
deps = [
|
||||
"//tensorflow_privacy/privacy/dp_query",
|
||||
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
|
||||
"//tensorflow_privacy/privacy/dp_query:restart_query",
|
||||
"//tensorflow_privacy/privacy/dp_query:tree_aggregation_query",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -47,6 +49,8 @@ py_library(
|
|||
deps = [
|
||||
"//tensorflow_privacy/privacy/dp_query",
|
||||
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
|
||||
"//tensorflow_privacy/privacy/dp_query:restart_query",
|
||||
"//tensorflow_privacy/privacy/dp_query:tree_aggregation_query",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -13,12 +13,16 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Differentially private version of Keras optimizer v2."""
|
||||
from typing import Optional, Type
|
||||
from typing import List, Optional, Type, Union
|
||||
import warnings
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import restart_query
|
||||
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
|
||||
|
||||
_VarListType = List[Union[tf.Tensor, tf.Variable]]
|
||||
|
||||
|
||||
def _normalize(microbatch_gradient: tf.Tensor,
|
||||
|
@ -462,6 +466,102 @@ def make_gaussian_query_optimizer_class(cls):
|
|||
return return_gaussian_query_optimizer
|
||||
|
||||
|
||||
def make_dpftrl_tree_aggregation_optimizer_class(cls):
|
||||
"""Returns a differentially private follow-the-regularized-leader optimizer.
|
||||
|
||||
Args:
|
||||
cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
|
||||
"""
|
||||
|
||||
def return_dpftrl_tree_aggregation_optimizer(
|
||||
l2_norm_clip: float,
|
||||
noise_multiplier: float,
|
||||
var_list_or_model: Union[_VarListType, tf.keras.Model],
|
||||
num_microbatches: Optional[int] = None,
|
||||
gradient_accumulation_steps: int = 1,
|
||||
restart_period: Optional[int] = None,
|
||||
restart_warmup: Optional[int] = None,
|
||||
noise_seed: Optional[int] = None,
|
||||
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
|
||||
**kwargs):
|
||||
"""Returns a `DPOptimizerClass` `cls` using the `TreeAggregationQuery`.
|
||||
|
||||
Combining this query with a SGD optimizer can be used to implement the
|
||||
DP-FTRL algorithm in
|
||||
"Practical and Private (Deep) Learning without Sampling or Shuffling".
|
||||
|
||||
This function is a thin wrapper around
|
||||
`make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
|
||||
apply a `TreeAggregationQuery` to any `DPOptimizerClass`.
|
||||
|
||||
Args:
|
||||
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
|
||||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||
var_list_or_model: Either a tf.keras.Model or a list of tf.variables from
|
||||
which `tf.TensorSpec`s can be defined. These specify the structure and
|
||||
shapes of records (gradients).
|
||||
num_microbatches: Number of microbatches into which each minibatch is
|
||||
split. Default is `None` which means that number of microbatches is
|
||||
equal to batch size (i.e. each microbatch contains exactly one example).
|
||||
If `gradient_accumulation_steps` is greater than 1 and
|
||||
`num_microbatches` is not `None` then the effective number of
|
||||
microbatches is equal to `num_microbatches *
|
||||
gradient_accumulation_steps`.
|
||||
gradient_accumulation_steps: If greater than 1 then optimizer will be
|
||||
accumulating gradients for this number of optimizer steps before
|
||||
applying them to update model weights. If this argument is set to 1 then
|
||||
updates will be applied on each optimizer step.
|
||||
restart_period: (Optional) Restart wil occur after `restart_period` steps.
|
||||
The default (None) means there will be no periodic restarts. Must be a
|
||||
positive integer. If `restart_warmup` is passed, this only applies to
|
||||
the second restart and onwards and must be not None.
|
||||
restart_warmup: (Optional) The first restart will occur after
|
||||
`restart_warmup` steps. The default (None) means no warmup. Must be an
|
||||
integer in the range [1, `restart_period` - 1].
|
||||
noise_seed: (Optional) Integer seed for the Gaussian noise generator. If
|
||||
`None`, a nondeterministic seed based on system time will be generated.
|
||||
*args: These will be passed on to the base class `__init__` method.
|
||||
**kwargs: These will be passed on to the base class `__init__` method.
|
||||
Raise:
|
||||
ValueError: If restart_warmup is not None and restart_period is None.
|
||||
"""
|
||||
if restart_warmup is not None and restart_period is None:
|
||||
raise ValueError(
|
||||
'`restart_period` was None when `restart_warmup` was not None.')
|
||||
|
||||
if isinstance(var_list_or_model, tf.keras.layers.Layer):
|
||||
model_trainable_specs = tf.nest.map_structure(
|
||||
lambda t: tf.TensorSpec(t.shape),
|
||||
var_list_or_model.trainable_variables)
|
||||
else:
|
||||
model_trainable_specs = tf.nest.map_structure(
|
||||
lambda t: tf.TensorSpec(tf.shape(t)), var_list_or_model)
|
||||
|
||||
if restart_period is not None:
|
||||
sum_query = (
|
||||
tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
|
||||
l2_norm_clip, noise_multiplier, model_trainable_specs,
|
||||
noise_seed))
|
||||
restart_indicator = restart_query.PeriodicRoundRestartIndicator(
|
||||
period=restart_period, warmup=restart_warmup)
|
||||
tree_aggregation_sum_query = restart_query.RestartQuery(
|
||||
sum_query, restart_indicator)
|
||||
else:
|
||||
tree_aggregation_sum_query = (
|
||||
tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
|
||||
l2_norm_clip, noise_multiplier, model_trainable_specs,
|
||||
noise_seed))
|
||||
|
||||
return cls(
|
||||
dp_sum_query=tree_aggregation_sum_query,
|
||||
num_microbatches=num_microbatches,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
return return_dpftrl_tree_aggregation_optimizer
|
||||
|
||||
|
||||
def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
|
||||
"""Returns a differentially private optimizer using the `GaussianSumQuery`.
|
||||
|
||||
|
@ -487,6 +587,8 @@ GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
|
|||
GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
|
||||
tf.keras.optimizers.legacy.SGD)
|
||||
|
||||
DPFTRLTreeAggregationOptimizer = (
|
||||
make_dpftrl_tree_aggregation_optimizer_class(GenericDPSGDOptimizer))
|
||||
# We keep the same names for backwards compatibility.
|
||||
DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
|
||||
GenericDPAdagradOptimizer)
|
||||
|
|
|
@ -135,18 +135,24 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer),
|
||||
('DPGradientDescentVectorized',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
|
||||
('DPGradientDescent_False', dp_optimizer_keras.DPKerasSGDOptimizer,
|
||||
False),
|
||||
('DPGradientDescentVectorized_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False),
|
||||
('DPFTRLTreeAggregation_True',
|
||||
dp_optimizer_keras.DPFTRLTreeAggregationOptimizer, True),
|
||||
)
|
||||
def testClippingNorm(self, optimizer_class):
|
||||
def testClippingNorm(self, optimizer_class, requires_varlist):
|
||||
var0 = tf.Variable([0.0, 0.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
|
||||
|
||||
varlist_kwarg = {'var_list_or_model': [var0]} if requires_varlist else {}
|
||||
|
||||
optimizer = optimizer_class(
|
||||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
**varlist_kwarg,
|
||||
learning_rate=2.0)
|
||||
|
||||
loss = lambda: self._loss(data0, var0)
|
||||
|
@ -155,24 +161,31 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.5, 1),
|
||||
('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2.5, 2),
|
||||
('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 2.5, 4),
|
||||
('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1, False),
|
||||
('DPGradientDescent_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2, False),
|
||||
('DPGradientDescent_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4, False),
|
||||
('DPGradientDescentVectorized',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.5, 1),
|
||||
)
|
||||
def testClippingNormMultipleVariables(self, cls, l2_clip_norm,
|
||||
num_microbatches):
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, False),
|
||||
('DPFTRLTreeAggregation_4',
|
||||
dp_optimizer_keras.DPFTRLTreeAggregationOptimizer, 4, True))
|
||||
def testClippingNormMultipleVariables(self, cls, num_microbatches,
|
||||
requires_varlist):
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
var1 = tf.Variable([3.0])
|
||||
data0 = tf.Variable([[3.0, 6.0], [5.0, 6.0], [4.0, 8.0], [-1.0, 0.0]])
|
||||
data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
|
||||
l2_clip_norm = 2.5
|
||||
|
||||
varlist_kwarg = {
|
||||
'var_list_or_model': [var0, var1]
|
||||
} if requires_varlist else {}
|
||||
|
||||
opt = cls(
|
||||
l2_norm_clip=l2_clip_norm,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=num_microbatches,
|
||||
learning_rate=2.0)
|
||||
learning_rate=2.0,
|
||||
**varlist_kwarg)
|
||||
|
||||
loss = lambda: self._loss(data0, var0) + self._loss(data1, var1)
|
||||
|
||||
|
@ -202,26 +215,28 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
||||
4.0, 1),
|
||||
4.0, 1, False),
|
||||
('DPGradientDescent_4_1_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0,
|
||||
1.0, 4),
|
||||
1.0, 4, False),
|
||||
('DPGradientDescentVectorized_2_4_1',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
|
||||
1),
|
||||
('DPGradientDescentVectorized_4_1_4',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4.0, 1.0,
|
||||
4),
|
||||
)
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0, 1,
|
||||
False), ('DPGradientDescentVectorized_4_1_4',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer,
|
||||
4.0, 1.0, 4, False),
|
||||
('DPFTRLTreeAggregation_2_4_1',
|
||||
dp_optimizer_keras.DPFTRLTreeAggregationOptimizer, 2.0, 4.0, 1, True))
|
||||
def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier,
|
||||
num_microbatches):
|
||||
tf.random.set_seed(2)
|
||||
num_microbatches, requires_varlist):
|
||||
var0 = tf.Variable(tf.zeros([1000], dtype=tf.float32))
|
||||
data0 = tf.Variable(tf.zeros([16, 1000], dtype=tf.float32))
|
||||
|
||||
varlist_kwarg = {'var_list_or_model': [var0]} if requires_varlist else {}
|
||||
|
||||
optimizer = optimizer_class(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=noise_multiplier,
|
||||
num_microbatches=num_microbatches,
|
||||
**varlist_kwarg,
|
||||
learning_rate=2.0)
|
||||
|
||||
loss = lambda: self._loss(data0, var0)
|
||||
|
@ -233,36 +248,6 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertNear(
|
||||
np.std(grads), l2_norm_clip * noise_multiplier / num_microbatches, 0.5)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer),
|
||||
('DPAdagrad', dp_optimizer_keras.DPKerasAdagradOptimizer),
|
||||
('DPAdam', dp_optimizer_keras.DPKerasAdamOptimizer),
|
||||
('DPGradientDescentVectorized',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
|
||||
('DPAdagradVectorized',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer),
|
||||
('DPAdamVectorized',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer),
|
||||
)
|
||||
def testRaisesOnNoCallOfComputeGradients(self, optimizer_class):
|
||||
"""Tests that assertion fails when DP gradients are not computed."""
|
||||
optimizer = optimizer_class(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
grads_and_vars = tf.Variable([0.0])
|
||||
optimizer.apply_gradients(grads_and_vars)
|
||||
|
||||
# Expect no exception if _compute_gradients is called.
|
||||
var0 = tf.Variable([0.0])
|
||||
data0 = tf.Variable([[0.0]])
|
||||
loss = lambda: self._loss(data0, var0)
|
||||
grads_and_vars = optimizer._compute_gradients(loss, [var0])
|
||||
optimizer.apply_gradients(grads_and_vars)
|
||||
|
||||
|
||||
class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
"""Tests for get_gradient method.
|
||||
|
@ -271,8 +256,13 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
the Estimator framework.
|
||||
"""
|
||||
|
||||
def _make_linear_model_fn(self, optimizer_class, l2_norm_clip,
|
||||
noise_multiplier, num_microbatches, learning_rate):
|
||||
def _make_linear_model_fn(self,
|
||||
optimizer_class,
|
||||
l2_norm_clip,
|
||||
noise_multiplier,
|
||||
num_microbatches,
|
||||
learning_rate,
|
||||
requires_varlist=False):
|
||||
"""Returns a model function for a linear regressor."""
|
||||
|
||||
def linear_model_fn(features, labels, mode):
|
||||
|
@ -287,10 +277,16 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
vector_loss = 0.5 * tf.math.squared_difference(labels, preds)
|
||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||
|
||||
# We also set the noise seed (since this is accepted by the constructor).
|
||||
if requires_varlist:
|
||||
varlist_kwarg = {'var_list_or_model': layer, 'noise_seed': 2}
|
||||
else:
|
||||
varlist_kwarg = {}
|
||||
optimizer = optimizer_class(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=noise_multiplier,
|
||||
num_microbatches=num_microbatches,
|
||||
**varlist_kwarg,
|
||||
learning_rate=learning_rate)
|
||||
|
||||
params = layer.trainable_weights
|
||||
|
@ -304,25 +300,36 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
return linear_model_fn
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
|
||||
('DPGradientDescent_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
|
||||
('DPGradientDescent_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
|
||||
('DPGradientDescent_None', dp_optimizer_keras.DPKerasSGDOptimizer, None),
|
||||
('DPGradientDescentVectorized_1',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
|
||||
('DPGradientDescentVectorized_2',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
|
||||
('DPGradientDescentVectorized_4',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
|
||||
('DPGradientDescentVectorized_None',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
|
||||
('DPGradientDescent_1_False', dp_optimizer_keras.DPKerasSGDOptimizer, 1,
|
||||
False),
|
||||
('DPGradientDescent_2_False', dp_optimizer_keras.DPKerasSGDOptimizer, 2,
|
||||
False),
|
||||
('DPGradientDescent_4_False', dp_optimizer_keras.DPKerasSGDOptimizer, 4,
|
||||
False),
|
||||
('DPGradientDescentVectorized_1_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, False),
|
||||
('DPGradientDescentVectorized_2_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2, False),
|
||||
('DPGradientDescentVectorized_4_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4, False),
|
||||
('DPGradientDescentVectorized_None_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None,
|
||||
False),
|
||||
('DPFTRLTreeAggregation_1_True',
|
||||
dp_optimizer_keras.DPFTRLTreeAggregationOptimizer, 1, True),
|
||||
)
|
||||
def testBaselineNoNoise(self, optimizer_class, num_microbatches):
|
||||
def testBaselineNoNoise(self, optimizer_class, num_microbatches,
|
||||
requires_varlist):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
||||
linear_regressor = tf_estimator.Estimator(
|
||||
model_fn=self._make_linear_model_fn(optimizer_class, 100.0, 0.0,
|
||||
num_microbatches, 0.05))
|
||||
model_fn=self._make_linear_model_fn(
|
||||
optimizer_class=optimizer_class,
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=num_microbatches,
|
||||
requires_varlist=requires_varlist,
|
||||
learning_rate=0.05))
|
||||
|
||||
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
||||
true_bias = np.array([6.0]).astype(np.float32)
|
||||
|
@ -346,11 +353,14 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer),
|
||||
('DPGradientDescentVectorized_1',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
|
||||
('DPGradientDescent_False', dp_optimizer_keras.DPKerasSGDOptimizer,
|
||||
False),
|
||||
('DPGradientDescentVectorized_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False),
|
||||
('DPFTRLTreeAggregation_True',
|
||||
dp_optimizer_keras.DPFTRLTreeAggregationOptimizer, True),
|
||||
)
|
||||
def testClippingNorm(self, optimizer_class):
|
||||
def testClippingNorm(self, optimizer_class, requires_varlist):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
||||
true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32)
|
||||
|
@ -369,6 +379,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
l2_norm_clip=1.0e9,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
requires_varlist=requires_varlist,
|
||||
learning_rate=1.0))
|
||||
unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
||||
|
||||
|
@ -382,6 +393,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
requires_varlist=requires_varlist,
|
||||
learning_rate=1.0))
|
||||
clipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
||||
|
||||
|
@ -394,35 +406,36 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
bias_value / global_norm,
|
||||
atol=0.001)
|
||||
|
||||
# Parameters for testing: optimizer, l2_norm_clip, noise_multiplier,
|
||||
# num_microbatches.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
||||
4.0, 1),
|
||||
('DPGradientDescent_3_2_4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0,
|
||||
2.0, 4),
|
||||
('DPGradientDescent_8_6_8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0,
|
||||
6.0, 8),
|
||||
('DPGradientDescentVectorized_2_4_1',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
|
||||
1),
|
||||
('DPGradientDescentVectorized_3_2_4',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0,
|
||||
4),
|
||||
('DPGradientDescentVectorized_8_6_8',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0,
|
||||
8),
|
||||
('DPGradientDescent_2_4_1_False', dp_optimizer_keras.DPKerasSGDOptimizer,
|
||||
2.0, 4.0, 1, False),
|
||||
('DPGradientDescent_3_2_4_False', dp_optimizer_keras.DPKerasSGDOptimizer,
|
||||
3.0, 2.0, 4, False),
|
||||
('DPGradientDescent_8_6_8_False', dp_optimizer_keras.DPKerasSGDOptimizer,
|
||||
8.0, 6.0, 8, False),
|
||||
('DPGradientDescentVectorized_2_4_1_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0, 1,
|
||||
False),
|
||||
('DPGradientDescentVectorized_3_2_4_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0, 4,
|
||||
False),
|
||||
('DPGradientDescentVectorized_8_6_8_False',
|
||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0, 8,
|
||||
False),
|
||||
('DPFTRLTreeAggregation_8_4_2_True',
|
||||
dp_optimizer_keras.DPFTRLTreeAggregationOptimizer, 8.0, 4.0, 1, True),
|
||||
)
|
||||
def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier,
|
||||
num_microbatches):
|
||||
num_microbatches, requires_varlist):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
||||
tf.random.set_seed(2)
|
||||
linear_regressor = tf_estimator.Estimator(
|
||||
model_fn=self._make_linear_model_fn(
|
||||
optimizer_class,
|
||||
l2_norm_clip,
|
||||
noise_multiplier,
|
||||
num_microbatches,
|
||||
requires_varlist=requires_varlist,
|
||||
learning_rate=1.0))
|
||||
|
||||
true_weights = np.zeros((1000, 1), dtype=np.float32)
|
||||
|
@ -745,5 +758,283 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
|
|||
self.fail('ValueError raised by model.fit().')
|
||||
|
||||
|
||||
class DPTreeAggregationOptimizerComputeGradientsTest(tf.test.TestCase,
|
||||
parameterized.TestCase):
|
||||
"""Tests for _compute_gradients method."""
|
||||
|
||||
def _loss(self, val0, val1):
|
||||
"""Loss function whose derivative w.r.t val1 is val1 - val0."""
|
||||
return 0.5 * tf.reduce_sum(
|
||||
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('1_None_None', 1, None, None),
|
||||
('2_2_1', 2, 2, 1),
|
||||
('4_1_None', 4, 1, None),
|
||||
('4_4_2', 4, 4, 2),
|
||||
)
|
||||
def testBaselineWithCallableLossNoNoise(self, num_microbatches,
|
||||
restart_period, restart_warmup):
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
var1 = tf.Variable([3.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||
data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
|
||||
expected_grad0 = [-2.5, -2.5]
|
||||
expected_grad1 = [-0.5]
|
||||
|
||||
optimizer = dp_optimizer_keras.DPFTRLTreeAggregationOptimizer(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
var_list_or_model=[var0, var1],
|
||||
num_microbatches=num_microbatches,
|
||||
restart_period=restart_period,
|
||||
restart_warmup=restart_warmup,
|
||||
learning_rate=2.0)
|
||||
|
||||
loss = lambda: self._loss(data0, var0) + self._loss(data1, var1)
|
||||
|
||||
grads_and_vars = optimizer._compute_gradients(loss, [var0, var1])
|
||||
|
||||
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
|
||||
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('1_None_None', 1, None, None),
|
||||
('2_2_1', 2, 2, 1),
|
||||
('4_1_None', 4, 1, None),
|
||||
('4_4_2', 4, 4, 2),
|
||||
)
|
||||
def testBaselineWithTensorLossNoNoise(self, num_microbatches, restart_period,
|
||||
restart_warmup):
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
var1 = tf.Variable([3.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||
data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
|
||||
expected_grad0 = [-2.5, -2.5]
|
||||
expected_grad1 = [-0.5]
|
||||
|
||||
optimizer = dp_optimizer_keras.DPFTRLTreeAggregationOptimizer(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
var_list_or_model=[var0, var1],
|
||||
num_microbatches=num_microbatches,
|
||||
restart_period=restart_period,
|
||||
restart_warmup=restart_warmup,
|
||||
learning_rate=2.0)
|
||||
|
||||
tape = tf.GradientTape()
|
||||
with tape:
|
||||
loss = self._loss(data0, var0) + self._loss(data1, var1)
|
||||
|
||||
grads_and_vars = optimizer._compute_gradients(loss, [var0, var1], tape=tape)
|
||||
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
|
||||
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
|
||||
|
||||
def testRaisesOnNoCallOfComputeGradients(self):
|
||||
"""Tests that assertion fails when DP gradients are not computed."""
|
||||
variables = [tf.Variable([0.0])]
|
||||
optimizer = dp_optimizer_keras.DPFTRLTreeAggregationOptimizer(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
learning_rate=2.0,
|
||||
restart_period=None,
|
||||
restart_warmup=None,
|
||||
var_list_or_model=variables)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
optimizer.apply_gradients(variables)
|
||||
|
||||
# Expect no exception if _compute_gradients is called.
|
||||
data0 = tf.Variable([[0.0]])
|
||||
loss = lambda: self._loss(data0, variables[0])
|
||||
grads_and_vars = optimizer._compute_gradients(loss, variables[0])
|
||||
optimizer.apply_gradients(grads_and_vars)
|
||||
|
||||
|
||||
class DPTreeAggregationGetGradientsTest(tf.test.TestCase,
|
||||
parameterized.TestCase):
|
||||
"""Tests for get_gradient method.
|
||||
|
||||
Since get_gradients must run in graph mode, the method is tested within
|
||||
the Estimator framework.
|
||||
"""
|
||||
|
||||
def _make_linear_model_fn(self, l2_norm_clip, noise_multiplier,
|
||||
num_microbatches, restart_period, restart_warmup,
|
||||
learning_rate):
|
||||
"""Returns a model function for a linear regressor."""
|
||||
|
||||
def linear_model_fn(features, labels, mode):
|
||||
layer = tf.keras.layers.Dense(
|
||||
1,
|
||||
activation='linear',
|
||||
name='dense',
|
||||
kernel_initializer='zeros',
|
||||
bias_initializer='zeros')
|
||||
preds = layer(features)
|
||||
|
||||
vector_loss = 0.5 * tf.math.squared_difference(labels, preds)
|
||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||
|
||||
optimizer = dp_optimizer_keras.DPFTRLTreeAggregationOptimizer(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=noise_multiplier,
|
||||
num_microbatches=num_microbatches,
|
||||
var_list_or_model=layer,
|
||||
restart_period=restart_period,
|
||||
restart_warmup=restart_warmup,
|
||||
learning_rate=learning_rate)
|
||||
|
||||
params = layer.trainable_weights
|
||||
global_step = tf.compat.v1.train.get_global_step()
|
||||
train_op = tf.group(
|
||||
optimizer.get_updates(loss=vector_loss, params=params),
|
||||
[tf.compat.v1.assign_add(global_step, 1)])
|
||||
return tf_estimator.EstimatorSpec(
|
||||
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||
|
||||
return linear_model_fn
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('1_None_None', 1, None, None),
|
||||
('2_1_1', 2, 2, 1),
|
||||
('4_1_None', 4, 1, None),
|
||||
('4_4_2', 4, 4, 2),
|
||||
)
|
||||
def testBaselineNoNoise(self, num_microbatches, restart_period,
|
||||
restart_warmup):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
||||
linear_regressor = tf_estimator.Estimator(
|
||||
model_fn=self._make_linear_model_fn(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=num_microbatches,
|
||||
restart_period=restart_period,
|
||||
restart_warmup=restart_warmup,
|
||||
learning_rate=0.05))
|
||||
|
||||
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
||||
true_bias = np.array([6.0]).astype(np.float32)
|
||||
train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
|
||||
|
||||
train_labels = np.matmul(train_data,
|
||||
true_weights) + true_bias + np.random.normal(
|
||||
scale=0.0, size=(1000, 1)).astype(np.float32)
|
||||
|
||||
def train_input_fn():
|
||||
return tf.data.Dataset.from_tensor_slices(
|
||||
(train_data, train_labels)).batch(8)
|
||||
|
||||
linear_regressor.train(input_fn=train_input_fn, steps=125)
|
||||
|
||||
self.assertAllClose(
|
||||
linear_regressor.get_variable_value('dense/kernel'),
|
||||
true_weights,
|
||||
atol=0.05)
|
||||
self.assertAllClose(
|
||||
linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05)
|
||||
|
||||
def testRaisesOnNoCallOfGetGradients(self):
|
||||
"""Tests that assertion fails when DP gradients are not computed."""
|
||||
grads_and_vars = tf.Variable([0.0])
|
||||
optimizer = dp_optimizer_keras.DPFTRLTreeAggregationOptimizer(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
var_list_or_model=[grads_and_vars],
|
||||
restart_period=None,
|
||||
restart_warmup=None,
|
||||
learning_rate=2.0)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
optimizer.apply_gradients(grads_and_vars)
|
||||
|
||||
def testLargeBatchEmulationNoNoise(self):
|
||||
# Test for emulation of large batch training.
|
||||
# It tests that updates are only done every gradient_accumulation_steps
|
||||
# steps.
|
||||
# In this test we set noise multiplier to zero and clipping norm to high
|
||||
# value, such that optimizer essentially behave as non-DP optimizer.
|
||||
# This makes easier to check how values of variables are changing.
|
||||
#
|
||||
# This test optimizes loss var0*x + var1
|
||||
# Gradients of this loss are computed as:
|
||||
# d(loss)/d(var0) = x
|
||||
# d(loss)/d(var1) = 1
|
||||
var0 = tf.Variable([[1.0, 2.0]], dtype=tf.float32)
|
||||
var1 = tf.Variable([3.0], dtype=tf.float32)
|
||||
x1 = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
|
||||
loss1 = lambda: tf.matmul(var0, x1, transpose_b=True) + var1
|
||||
x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32)
|
||||
loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1
|
||||
variables = [var0, var1]
|
||||
|
||||
optimizer = dp_optimizer_keras.DPFTRLTreeAggregationOptimizer(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
gradient_accumulation_steps=2,
|
||||
var_list_or_model=variables,
|
||||
restart_period=None,
|
||||
restart_warmup=None,
|
||||
learning_rate=1.0)
|
||||
|
||||
# before any call to optimizer
|
||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
||||
self.assertAllCloseAccordingToType([3.0], var1)
|
||||
|
||||
optimizer.minimize(loss1, variables)
|
||||
# After first call to optimizer values didn't change
|
||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
||||
self.assertAllCloseAccordingToType([3.0], var1)
|
||||
|
||||
optimizer.minimize(loss2, variables)
|
||||
# After second call to optimizer updates were applied
|
||||
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
||||
self.assertAllCloseAccordingToType([2.0], var1)
|
||||
|
||||
optimizer.minimize(loss2, variables)
|
||||
# After third call to optimizer values didn't change
|
||||
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
||||
self.assertAllCloseAccordingToType([2.0], var1)
|
||||
|
||||
optimizer.minimize(loss2, variables)
|
||||
# After fourth call to optimizer updates were applied again
|
||||
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
|
||||
self.assertAllCloseAccordingToType([1.0], var1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('1', 1),
|
||||
('2', 2),
|
||||
('4', 4),
|
||||
)
|
||||
def testLargeBatchEmulation(self, gradient_accumulation_steps):
|
||||
# Uses clipping and noise, thus does not test specific values
|
||||
# of the variables and only tests how often variables are updated.
|
||||
var0 = tf.Variable([[1.0, 2.0]], dtype=tf.float32)
|
||||
var1 = tf.Variable([3.0], dtype=tf.float32)
|
||||
x = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
|
||||
loss = lambda: tf.matmul(var0, x, transpose_b=True) + var1
|
||||
variables = [var0, var1]
|
||||
|
||||
optimizer = dp_optimizer_keras.DPFTRLTreeAggregationOptimizer(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
var_list_or_model=variables,
|
||||
restart_period=None,
|
||||
restart_warmup=None,
|
||||
learning_rate=1.0)
|
||||
|
||||
for _ in range(gradient_accumulation_steps):
|
||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
||||
self.assertAllCloseAccordingToType([3.0], var1)
|
||||
optimizer.minimize(loss, variables)
|
||||
|
||||
self.assertNotAllClose([[1.0, 2.0]], var0)
|
||||
self.assertNotAllClose([3.0], var1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue