forked from 626_privacy/tensorflow_privacy

Changes DPOptimizerClass to generically accept and use any dp_sum_query.

This enables creation of generic DPOptimizers by users passing queries. The most common Gaussian query is constructed automatically for convenience and backwards compatibility. Byproducts of this update:
- ensures consistent implementations between the internal (and legacy) `get_gradients` and the newer `_compute_gradients` for all queries.
- refactors for Python readability.
- adds new tests ensuring that `_num_microbatches=None` is covered.
- changes `_global_state` to be initialized in the init function for `_compute_gradients`.

PiperOrigin-RevId: 480668376

This commit is contained in:
parent f8ed0fcd9c
commit 79fe32a60b

4 changed files with 388 additions and 215 deletions
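For orientation, the shape of the change: optimizers are now built around an explicit `DPQuery`, while the Gaussian case keeps its old constructor through a backwards-compatible wrapper. A minimal sketch (the hyperparameter values are illustrative, not taken from the diff):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

# New generic path: build any dp_sum_query and hand it to a generic optimizer.
query = gaussian_query.GaussianSumQuery(1.0, 1.0 * 0.5)  # (l2_norm_clip, stddev)
opt = dp_optimizer_keras.GenericDPSGDOptimizer(
    dp_sum_query=query,
    num_microbatches=1,
    learning_rate=0.1)

# Backwards-compatible path: the Gaussian query is built for you.
opt_legacy = dp_optimizer_keras.DPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    num_microbatches=1,
    learning_rate=0.1)
```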
@@ -61,9 +61,14 @@ else:
   from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class
 
   # Optimizers
+  from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdagradOptimizer
+  from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer
+  from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPSGDOptimizer
   from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer
   from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer
   from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
+  from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_gaussian_query_optimizer_class
+  from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_generic_optimizer_class
   from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class
 
   from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasAdagradOptimizer
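The newly exported `make_keras_generic_optimizer_class` can also wrap optimizers beyond the three predefined ones; a sketch under that assumption (the `DPRMSprop` name and values below are made up for illustration):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import (
    make_keras_generic_optimizer_class)

# Hypothetical DP variant of another legacy Keras optimizer.
DPRMSprop = make_keras_generic_optimizer_class(tf.keras.optimizers.legacy.RMSprop)
query = gaussian_query.GaussianSumQuery(1.0, 0.5)  # (l2_norm_clip, stddev)
opt = DPRMSprop(dp_sum_query=query, num_microbatches=1, learning_rate=0.05)
```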
@@ -18,6 +18,18 @@ py_library(
     deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
 )
 
+py_library(
+    name = "dp_optimizer_factory",
+    srcs = [
+        "dp_optimizer_keras.py",
+    ],
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow_privacy/privacy/dp_query",
+        "//tensorflow_privacy/privacy/dp_query:gaussian_query",
+    ],
+)
+
 py_library(
     name = "dp_optimizer_vectorized",
     srcs = [
@@ -32,7 +44,10 @@ py_library(
         "dp_optimizer_keras.py",
     ],
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
+    deps = [
+        "//tensorflow_privacy/privacy/dp_query",
+        "//tensorflow_privacy/privacy/dp_query:gaussian_query",
+    ],
 )
 
 py_library(
@@ -84,7 +99,7 @@ py_test(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
-        "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
-        "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras_vectorized",
+        ":dp_optimizer_keras",
+        ":dp_optimizer_keras_vectorized",
     ],
 )
@@ -13,29 +13,40 @@
 # limitations under the License.
 # ==============================================================================
 """Differentially private version of Keras optimizer v2."""
+from typing import Optional, Type
+import warnings
 
 import tensorflow as tf
+from tensorflow_privacy.privacy.dp_query import dp_query
 from tensorflow_privacy.privacy.dp_query import gaussian_query
 
 
-def make_keras_optimizer_class(cls):
-  """Given a subclass of `tf.keras.optimizers.legacy.Optimizer`, returns a DP-SGD subclass of it.
+def _normalize(microbatch_gradient: tf.Tensor,
+               num_microbatches: float) -> tf.Tensor:
+  """Normalizes `microbatch_gradient` by `num_microbatches`."""
+  return tf.truediv(microbatch_gradient,
+                    tf.cast(num_microbatches, microbatch_gradient.dtype))
+
+
+def make_keras_generic_optimizer_class(
+    cls: Type[tf.keras.optimizers.Optimizer]):
+  """Returns a differentially private (DP) subclass of `cls`.
 
   Args:
     cls: Class from which to derive a DP subclass. Should be a subclass of
       `tf.keras.optimizers.legacy.Optimizer`.
 
   Returns:
-    A DP-SGD subclass of `cls`.
+    A generic DP-SGD subclass of `cls`, compatible with many DP queries.
   """
 
-  class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
+  class DPOptimizerClass(cls):  # pylint: disable=empty-docstring,missing-class-docstring
     __doc__ = """Differentially private subclass of class `{base_class}`.
 
   You can use this as a differentially private replacement for
-  `{base_class}`. This optimizer implements DP-SGD using
-  the standard Gaussian mechanism.
+  `{base_class}`. This optimizer implements a differentially private version
+  of the stochastic gradient descent optimizer `cls` using the chosen
+  `dp_query.DPQuery` instance.
 
   When instantiating this optimizer, you need to supply several
   DP-related arguments followed by the standard arguments for
@@ -45,8 +56,10 @@ def make_keras_optimizer_class(cls):
   ```python
   # Create optimizer.
-  opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
-                         <standard arguments>)
+  gaussian_query = gaussian_query.GaussianSumQuery(
+      l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1
+  )
+  opt = {dp_keras_class}(dp_sum_query=gaussian_query, <standard arguments>)
   ```
 
   When using the optimizer, be sure to pass in the loss as a
@@ -92,8 +105,10 @@ def make_keras_optimizer_class(cls):
   ```python
   # Create optimizer which will be accumulating gradients for 4 steps.
   # and then performing an update of model weights.
-  opt = {dp_keras_class}(l2_norm_clip=1.0,
-                         noise_multiplier=0.5,
+  gaussian_query = gaussian_query.GaussianSumQuery(
+      l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1
+  )
+  opt = {dp_keras_class}(dp_sum_query=gaussian_query,
                          num_microbatches=1,
                          gradient_accumulation_steps=4,
                          <standard arguments>)
@@ -138,24 +153,23 @@ def make_keras_optimizer_class(cls):
 
     def __init__(
         self,
-        l2_norm_clip,
-        noise_multiplier,
-        num_microbatches=None,
-        gradient_accumulation_steps=1,
+        dp_sum_query: dp_query.DPQuery,
+        num_microbatches: Optional[int] = None,
+        gradient_accumulation_steps: int = 1,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
-      """Initialize the DPOptimizerClass.
+      """Initializes the DPOptimizerClass.
 
       Args:
-        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
-        noise_multiplier: Ratio of the standard deviation to the clipping norm.
+        dp_sum_query: `DPQuery` object, specifying differential privacy
+          mechanism to use.
         num_microbatches: Number of microbatches into which each minibatch is
-          split. Default is `None` which means that number of microbatches
-          is equal to batch size (i.e. each microbatch contains exactly one
+          split. Default is `None` which means that number of microbatches is
+          equal to batch size (i.e. each microbatch contains exactly one
           example). If `gradient_accumulation_steps` is greater than 1 and
           `num_microbatches` is not `None` then the effective number of
-          microbatches is equal to
-          `num_microbatches * gradient_accumulation_steps`.
+          microbatches is equal to `num_microbatches *
+          gradient_accumulation_steps`.
         gradient_accumulation_steps: If greater than 1 then optimizer will be
           accumulating gradients for this number of optimizer steps before
           applying them to update model weights. If this argument is set to 1
@@ -165,13 +179,16 @@ def make_keras_optimizer_class(cls):
      """
      super().__init__(*args, **kwargs)
      self.gradient_accumulation_steps = gradient_accumulation_steps
-      self._l2_norm_clip = l2_norm_clip
-      self._noise_multiplier = noise_multiplier
      self._num_microbatches = num_microbatches
-      self._dp_sum_query = gaussian_query.GaussianSumQuery(
-          l2_norm_clip, l2_norm_clip * noise_multiplier)
-      self._global_state = None
+      self._dp_sum_query = dp_sum_query
      self._was_dp_gradients_called = False
+      # We initialize here for `_compute_gradients` because of requirements from
+      # the tf.keras.Model API. Specifically, keras models use the
+      # `_compute_gradients` method for both eager and graph mode. So,
+      # instantiating the state here is necessary to avoid graph compilation
+      # issues.
+
+      self._global_state = self._dp_sum_query.initial_global_state()
 
     def _create_slots(self, var_list):
      super()._create_slots(var_list)  # pytype: disable=attribute-error
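For readers unfamiliar with the `DPQuery` interface the optimizer now drives (including the `initial_global_state` call that moved into `__init__` above), here is a runnable sketch with made-up records standing in for per-microbatch gradients; the query calls match the ones used in this file:

```python
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query

# Two fake "microbatch gradient" records with the same structure as a var_list.
records = [[tf.constant([3.0, 4.0])], [tf.constant([0.5, 0.0])]]

query = gaussian_query.GaussianSumQuery(1.0, 0.0)  # (l2_norm_clip, stddev)
global_state = query.initial_global_state()             # now done in __init__
sample_params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(records[0])   # template for the sum
for record in records:                                   # one per microbatch
  sample_state = query.accumulate_record(sample_params, sample_state, record)
grad_sums, global_state, _ = query.get_noised_result(sample_state, global_state)
# grad_sums holds the clipped (and, with nonzero stddev, noised) sum of records.
```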
@@ -233,82 +250,88 @@ def make_keras_optimizer_class(cls):
 
     def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
       """DP-SGD version of base class method."""
 
       self._was_dp_gradients_called = True
 
       # Compute loss.
       if not callable(loss) and tape is None:
         raise ValueError('`tape` is required when a `Tensor` loss is passed.')
 
       tape = tape if tape is not None else tf.GradientTape()
 
-      if callable(loss):
-        with tape:
+      with tape:
+        if callable(loss):
           if not callable(var_list):
             tape.watch(var_list)
 
           loss = loss()
-          if self._num_microbatches is None:
-            num_microbatches = tf.shape(input=loss)[0]
-          else:
-            num_microbatches = self._num_microbatches
-          microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [num_microbatches, -1]), axis=1)
 
-          if callable(var_list):
-            var_list = var_list()
-      else:
-        with tape:
-          if self._num_microbatches is None:
-            num_microbatches = tf.shape(input=loss)[0]
-          else:
-            num_microbatches = self._num_microbatches
-          microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [num_microbatches, -1]), axis=1)
+        if self._num_microbatches is None:
+          num_microbatches = tf.shape(input=loss)[0]
+        else:
+          num_microbatches = self._num_microbatches
+        microbatch_losses = tf.reduce_mean(
+            tf.reshape(loss, [num_microbatches, -1]), axis=1)
+
+        if callable(var_list):
+          var_list = var_list()
 
       var_list = tf.nest.flatten(var_list)
 
+      sample_params = (
+          self._dp_sum_query.derive_sample_params(self._global_state))
+
       # Compute the per-microbatch losses using helpful jacobian method.
       with tf.keras.backend.name_scope(self._name + '/gradients'):
-        jacobian = tape.jacobian(
+        jacobian_per_var = tape.jacobian(
             microbatch_losses, var_list, unconnected_gradients='zero')
 
-        def clip_gradients(g):
-          """Clips gradients to given l2_norm_clip."""
-          return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]
+        def process_microbatch(sample_state, microbatch_jacobians):
+          """Process one microbatch (record) with privacy helper."""
+          sample_state = self._dp_sum_query.accumulate_record(
+              sample_params, sample_state, microbatch_jacobians)
+          return sample_state
 
-        # Clip all gradients. Note that `tf.map_fn` applies the given function
-        # to its arguments unstacked along axis 0.
-        clipped_gradients = tf.map_fn(clip_gradients, jacobian)
+        sample_state = self._dp_sum_query.initial_sample_state(var_list)
 
-        def reduce_noise_normalize_batch(g):
-          # Sum gradients over all microbatches.
-          summed_gradient = tf.reduce_sum(g, axis=0)
+        def body_fn(idx, sample_state):
+          microbatch_jacobians_per_var = [
+              jacobian[idx] for jacobian in jacobian_per_var
+          ]
+          sample_state = process_microbatch(sample_state,
+                                            microbatch_jacobians_per_var)
+          return tf.add(idx, 1), sample_state
 
-          # Add noise to summed gradients.
-          noise_stddev = self._l2_norm_clip * self._noise_multiplier
-          noise = tf.random.normal(
-              tf.shape(input=summed_gradient), stddev=noise_stddev)
-          noised_gradient = tf.add(summed_gradient, noise)
+        cond_fn = lambda idx, _: tf.less(idx, num_microbatches)
+        idx = tf.constant(0)
+        _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
 
-          # Normalize by number of microbatches and return.
-          return tf.truediv(noised_gradient,
-                            tf.cast(num_microbatches, tf.float32))
+        grad_sums, self._global_state, _ = (
+            self._dp_sum_query.get_noised_result(sample_state,
+                                                 self._global_state))
+        final_grads = tf.nest.map_structure(_normalize, grad_sums,
+                                            [num_microbatches] * len(grad_sums))
 
-        final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
-                                                clipped_gradients)
-
-      return list(zip(final_gradients, var_list))
+      return list(zip(final_grads, var_list))
 
     def get_gradients(self, loss, params):
       """DP-SGD version of base class method."""
-      self._was_dp_gradients_called = True
-      if self._global_state is None:
+      if not self._was_dp_gradients_called:
+        # We create the global state here due to tf.Estimator API requirements,
+        # specifically, that instantiating the global state outside this
+        # function leads to graph compilation errors of attempting to capture an
+        # EagerTensor.
         self._global_state = self._dp_sum_query.initial_global_state()
+      self._was_dp_gradients_called = True
 
       # This code mostly follows the logic in the original DPOptimizerClass
       # in dp_optimizer.py, except that this returns only the gradients,
       # not the gradients and variables.
-      microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
+      if self._num_microbatches is None:
+        num_microbatches = tf.shape(input=loss)[0]
+      else:
+        num_microbatches = self._num_microbatches
+
+      microbatch_losses = tf.reshape(loss, [num_microbatches, -1])
       sample_params = (
           self._dp_sum_query.derive_sample_params(self._global_state))
 
@@ -322,19 +345,20 @@ def make_keras_optimizer_class(cls):
         return sample_state
 
       sample_state = self._dp_sum_query.initial_sample_state(params)
-      for idx in range(self._num_microbatches):
+
+      def body_fn(idx, sample_state):
         sample_state = process_microbatch(idx, sample_state)
+        return tf.add(idx, 1), sample_state
+
+      cond_fn = lambda idx, _: tf.less(idx, num_microbatches)
+      idx = tf.constant(0)
+      _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
 
       grad_sums, self._global_state, _ = (
           self._dp_sum_query.get_noised_result(sample_state,
                                                self._global_state))
 
-      def normalize(v):
-        try:
-          return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
-        except TypeError:
-          return None
-
-      final_grads = tf.nest.map_structure(normalize, grad_sums)
+      final_grads = tf.nest.map_structure(_normalize, grad_sums,
+                                          [num_microbatches] * len(grad_sums))
 
       return final_grads
 
@@ -351,8 +375,7 @@ def make_keras_optimizer_class(cls):
       """
       config = super().get_config()
       config.update({
-          'l2_norm_clip': self._l2_norm_clip,
-          'noise_multiplier': self._noise_multiplier,
+          'global_state': self._global_state._asdict(),
           'num_microbatches': self._num_microbatches,
       })
       return config
 
@@ -370,8 +393,103 @@ def make_keras_optimizer_class(cls):
   return DPOptimizerClass
 
 
-DPKerasAdagradOptimizer = make_keras_optimizer_class(
+def make_gaussian_query_optimizer_class(cls):
+  """Returns a differentially private optimizer using the `GaussianSumQuery`.
+
+  Args:
+    cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
+
+  Returns:
+    A DP-SGD subclass of `cls` using the `GaussianQuery`, the canonical DP-SGD
+    implementation.
+  """
+
+  def return_gaussian_query_optimizer(
+      l2_norm_clip: float,
+      noise_multiplier: float,
+      num_microbatches: Optional[int] = None,
+      gradient_accumulation_steps: int = 1,
+      *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
+      **kwargs):
+    """Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.
+
+    This function is a thin wrapper around
+    `make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
+    apply a `GaussianSumQuery` to any `DPOptimizerClass`.
+
+    When combined with stochastic gradient descent, this creates the canonical
+    DP-SGD algorithm of "Deep Learning with Differential Privacy"
+    (see https://arxiv.org/abs/1607.00133).
+
+    When instantiating this optimizer, you need to supply several
+    DP-related arguments followed by the standard arguments for
+    `{short_base_class}`.
+
+    As an example, see the below or the documentation of the DPOptimizerClass.
+
+    ```python
+    # Create optimizer.
+    opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5,
+                           num_microbatches=1, <standard arguments>)
+    ```
+
+    Args:
+      l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+      noise_multiplier: Ratio of the standard deviation to the clipping norm.
+      num_microbatches: Number of microbatches into which each minibatch is
+        split. Default is `None` which means that number of microbatches is
+        equal to batch size (i.e. each microbatch contains exactly one example).
+        If `gradient_accumulation_steps` is greater than 1 and
+        `num_microbatches` is not `None` then the effective number of
+        microbatches is equal to `num_microbatches *
+        gradient_accumulation_steps`.
+      gradient_accumulation_steps: If greater than 1 then optimizer will be
+        accumulating gradients for this number of optimizer steps before
+        applying them to update model weights. If this argument is set to 1 then
+        updates will be applied on each optimizer step.
+      *args: These will be passed on to the base class `__init__` method.
+      **kwargs: These will be passed on to the base class `__init__` method.
+    """
+    dp_sum_query = gaussian_query.GaussianSumQuery(
+        l2_norm_clip, l2_norm_clip * noise_multiplier)
+    return cls(
+        dp_sum_query=dp_sum_query,
+        num_microbatches=num_microbatches,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        *args,
+        **kwargs)
+
+  return return_gaussian_query_optimizer
+
+
+def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
+  """Returns a differentially private optimizer using the `GaussianSumQuery`.
+
+  For backwards compatibility, we create this symbol to match the previous
+  output of `make_keras_optimizer_class` but using the new logic.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.keras.optimizers.Optimizer`.
+  """
+  warnings.warn(
+      '`make_keras_optimizer_class` will be deprecated on 2023-02-23. '
+      'Please switch to `make_gaussian_query_optimizer_class` and the '
+      'generic optimizers (`make_keras_generic_optimizer_class`).')
+  return make_gaussian_query_optimizer_class(
+      make_keras_generic_optimizer_class(cls))
+
+
+GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
     tf.keras.optimizers.legacy.Adagrad)
-DPKerasAdamOptimizer = make_keras_optimizer_class(
+GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
     tf.keras.optimizers.legacy.Adam)
-DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.legacy.SGD)
+GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
+    tf.keras.optimizers.legacy.SGD)
+
+# We keep the same names for backwards compatibility.
+DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
+    GenericDPAdagradOptimizer)
+DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
+    GenericDPAdamOptimizer)
+DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)
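To make the backwards-compatibility story concrete, the deprecated factory now just composes the two new factories; a sketch (values illustrative):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

# Old spelling: still works, but emits a deprecation warning...
DPSGD_old = dp_optimizer_keras.make_keras_optimizer_class(
    tf.keras.optimizers.legacy.SGD)

# ...and resolves to the same composition as the explicit new spelling.
DPSGD_new = dp_optimizer_keras.make_gaussian_query_optimizer_class(
    dp_optimizer_keras.make_keras_generic_optimizer_class(
        tf.keras.optimizers.legacy.SGD))

opt = DPSGD_new(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
                learning_rate=0.1)
```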
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -29,36 +28,30 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
     return 0.5 * tf.reduce_sum(
         input_tensor=tf.math.squared_difference(val0, val1), axis=1)
 
-  # Parameters for testing: optimizer, num_microbatches, expected gradient for
-  # var0, expected gradient for var1.
   @parameterized.named_parameters(
-      ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, [-2.5, -2.5
-      ], [-0.5]),
-      ('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4,
-       [-2.5, -2.5], [-0.5]),
-      ('DPGradientDescentVectorized 1',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdamVectorized 2',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdagradVectorized 4',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdagradVectorized None',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
-       [-2.5, -2.5], [-0.5]),
+      ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
+      ('DPGradientDescent_None', dp_optimizer_keras.DPKerasSGDOptimizer, None),
+      ('DPAdam_2', dp_optimizer_keras.DPKerasAdamOptimizer, 2),
+      ('DPAdagrad_4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4),
+      ('DPGradientDescentVectorized_1',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
+      ('DPAdamVectorized_2',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2),
+      ('DPAdagradVectorized_4',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4),
+      ('DPAdagradVectorized_None',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None),
   )
-  def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0,
-                                   expected_grad1):
+  def testBaselineWithCallableLossNoNoise(self, optimizer_class,
+                                          num_microbatches):
     var0 = tf.Variable([1.0, 2.0])
     var1 = tf.Variable([3.0])
     data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
     data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
+    expected_grad0 = [-2.5, -2.5]
+    expected_grad1 = [-0.5]
 
-    opt = cls(
+    optimizer = optimizer_class(
         l2_norm_clip=100.0,
         noise_multiplier=0.0,
         num_microbatches=num_microbatches,
@@ -66,40 +59,68 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
 
     loss = lambda: self._loss(data0, var0) + self._loss(data1, var1)
 
-    grads_and_vars = opt._compute_gradients(loss, [var0, var1])
+    grads_and_vars = optimizer._compute_gradients(loss, [var0, var1])
 
     self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
     self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
 
-  # Parameters for testing: optimizer, num_microbatches, expected gradient for
-  # var0, expected gradient for var1.
+  def testKerasModelBaselineNoNoiseNoneMicrobatches(self):
+    """Tests that DP optimizers work with tf.keras.Model."""
+
+    model = tf.keras.models.Sequential(layers=[
+        tf.keras.layers.Dense(
+            1,
+            activation='linear',
+            name='dense',
+            kernel_initializer='zeros',
+            bias_initializer='zeros')
+    ])
+
+    optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
+        l2_norm_clip=100.0,
+        noise_multiplier=0.0,
+        num_microbatches=None,
+        learning_rate=0.05)
+    loss = tf.keras.losses.MeanSquaredError(reduction='none')
+    model.compile(optimizer, loss)
+
+    true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
+    true_bias = np.array([6.0]).astype(np.float32)
+    train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
+
+    train_labels = np.matmul(train_data,
+                             true_weights) + true_bias + np.random.normal(
+                                 scale=0.0, size=(1000, 1)).astype(np.float32)
+
+    model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)
+
+    self.assertAllClose(model.get_weights()[0], true_weights, atol=0.05)
+    self.assertAllClose(model.get_weights()[1], true_bias, atol=0.05)
+
   @parameterized.named_parameters(
-      ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, [-2.5, -2.5
-      ], [-0.5]),
-      ('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4,
-       [-2.5, -2.5], [-0.5]),
-      ('DPGradientDescentVectorized 1',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdamVectorized 2',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdagradVectorized 4',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
-       [-2.5, -2.5], [-0.5]),
-      ('DPAdagradVectorized None',
-       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
-       [-2.5, -2.5], [-0.5]),
+      ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
+      ('DPGradientDescent_None', dp_optimizer_keras.DPKerasSGDOptimizer, None),
+      ('DPAdam_2', dp_optimizer_keras.DPKerasAdamOptimizer, 2),
+      ('DPAdagrad_4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4),
+      ('DPGradientDescentVectorized_1',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
+      ('DPAdamVectorized_2',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2),
+      ('DPAdagradVectorized_4',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4),
+      ('DPAdagradVectorized_None',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None),
   )
-  def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0,
-                                 expected_grad1):
+  def testBaselineWithTensorLossNoNoise(self, optimizer_class,
+                                        num_microbatches):
     var0 = tf.Variable([1.0, 2.0])
     var1 = tf.Variable([3.0])
     data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
     data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
+    expected_grad0 = [-2.5, -2.5]
+    expected_grad1 = [-0.5]
 
-    opt = cls(
+    optimizer = optimizer_class(
         l2_norm_clip=100.0,
         noise_multiplier=0.0,
         num_microbatches=num_microbatches,
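The new Keras-model test above relies on a per-example (vector) loss so the optimizer can form microbatch losses itself; a condensed, runnable sketch of that usage (sizes and values are illustrative):

```python
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, activation='linear', input_shape=(4,))])
optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
    l2_norm_clip=100.0,
    noise_multiplier=0.0,
    num_microbatches=None,   # one microbatch per example
    learning_rate=0.05)
# reduction='none' keeps the loss as a vector with one entry per example.
loss = tf.keras.losses.MeanSquaredError(reduction='none')
model.compile(optimizer, loss)

x = np.random.normal(size=(256, 4)).astype(np.float32)
y = x @ np.array([[-5.0], [4.0], [3.0], [2.0]], dtype=np.float32) + 6.0
model.fit(x, y, batch_size=8, epochs=1)
```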
@ -109,7 +130,7 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
with tape:
|
with tape:
|
||||||
loss = self._loss(data0, var0) + self._loss(data1, var1)
|
loss = self._loss(data0, var0) + self._loss(data1, var1)
|
||||||
|
|
||||||
grads_and_vars = opt._compute_gradients(loss, [var0, var1], tape=tape)
|
grads_and_vars = optimizer._compute_gradients(loss, [var0, var1], tape=tape)
|
||||||
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
|
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
|
||||||
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
|
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
|
||||||
|
|
||||||
|
@ -118,11 +139,11 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
('DPGradientDescentVectorized',
|
('DPGradientDescentVectorized',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
|
||||||
)
|
)
|
||||||
def testClippingNorm(self, cls):
|
def testClippingNorm(self, optimizer_class):
|
||||||
var0 = tf.Variable([0.0, 0.0])
|
var0 = tf.Variable([0.0, 0.0])
|
||||||
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
|
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
|
||||||
|
|
||||||
opt = cls(
|
optimizer = optimizer_class(
|
||||||
l2_norm_clip=1.0,
|
l2_norm_clip=1.0,
|
||||||
noise_multiplier=0.0,
|
noise_multiplier=0.0,
|
||||||
num_microbatches=1,
|
num_microbatches=1,
|
||||||
|
@ -130,7 +151,7 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
loss = lambda: self._loss(data0, var0)
|
loss = lambda: self._loss(data0, var0)
|
||||||
# Expected gradient is sum of differences.
|
# Expected gradient is sum of differences.
|
||||||
grads_and_vars = opt._compute_gradients(loss, [var0])
|
grads_and_vars = optimizer._compute_gradients(loss, [var0])
|
||||||
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
|
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
|
@ -180,33 +201,35 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertAllCloseAccordingToType(expected1, grads_and_vars[1][0])
|
self.assertAllCloseAccordingToType(expected1, grads_and_vars[1][0])
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
||||||
4.0, 1),
|
4.0, 1),
|
||||||
('DPGradientDescent 4 1 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0,
|
('DPGradientDescent_4_1_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0,
|
||||||
1.0, 4),
|
1.0, 4),
|
||||||
('DPGradientDescentVectorized 2 4 1',
|
('DPGradientDescentVectorized_2_4_1',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
|
||||||
1),
|
1),
|
||||||
('DPGradientDescentVectorized 4 1 4',
|
('DPGradientDescentVectorized_4_1_4',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4.0, 1.0,
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4.0, 1.0,
|
||||||
4),
|
4),
|
||||||
)
|
)
|
||||||
def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier,
|
def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier,
|
||||||
num_microbatches):
|
num_microbatches):
|
||||||
|
tf.random.set_seed(2)
|
||||||
var0 = tf.Variable(tf.zeros([1000], dtype=tf.float32))
|
var0 = tf.Variable(tf.zeros([1000], dtype=tf.float32))
|
||||||
data0 = tf.Variable(tf.zeros([16, 1000], dtype=tf.float32))
|
data0 = tf.Variable(tf.zeros([16, 1000], dtype=tf.float32))
|
||||||
|
|
||||||
opt = cls(
|
optimizer = optimizer_class(
|
||||||
l2_norm_clip=l2_norm_clip,
|
l2_norm_clip=l2_norm_clip,
|
||||||
noise_multiplier=noise_multiplier,
|
noise_multiplier=noise_multiplier,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
learning_rate=2.0)
|
learning_rate=2.0)
|
||||||
|
|
||||||
loss = lambda: self._loss(data0, var0)
|
loss = lambda: self._loss(data0, var0)
|
||||||
grads_and_vars = opt._compute_gradients(loss, [var0])
|
grads_and_vars = optimizer._compute_gradients(loss, [var0])
|
||||||
grads = grads_and_vars[0][0].numpy()
|
grads = grads_and_vars[0][0].numpy()
|
||||||
|
|
||||||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||||
|
|
||||||
self.assertNear(
|
self.assertNear(
|
||||||
np.std(grads), l2_norm_clip * noise_multiplier / num_microbatches, 0.5)
|
np.std(grads), l2_norm_clip * noise_multiplier / num_microbatches, 0.5)
|
||||||
|
|
||||||
|
@ -221,9 +244,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
('DPAdamVectorized',
|
('DPAdamVectorized',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer),
|
||||||
)
|
)
|
||||||
def testAssertOnNoCallOfComputeGradients(self, cls):
|
def testRaisesOnNoCallOfComputeGradients(self, optimizer_class):
|
||||||
"""Tests that assertion fails when DP gradients are not computed."""
|
"""Tests that assertion fails when DP gradients are not computed."""
|
||||||
opt = cls(
|
optimizer = optimizer_class(
|
||||||
l2_norm_clip=100.0,
|
l2_norm_clip=100.0,
|
||||||
noise_multiplier=0.0,
|
noise_multiplier=0.0,
|
||||||
num_microbatches=1,
|
num_microbatches=1,
|
||||||
|
@ -231,14 +254,14 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
grads_and_vars = tf.Variable([0.0])
|
grads_and_vars = tf.Variable([0.0])
|
||||||
opt.apply_gradients(grads_and_vars)
|
optimizer.apply_gradients(grads_and_vars)
|
||||||
|
|
||||||
# Expect no exception if _compute_gradients is called.
|
# Expect no exception if _compute_gradients is called.
|
||||||
var0 = tf.Variable([0.0])
|
var0 = tf.Variable([0.0])
|
||||||
data0 = tf.Variable([[0.0]])
|
data0 = tf.Variable([[0.0]])
|
||||||
loss = lambda: self._loss(data0, var0)
|
loss = lambda: self._loss(data0, var0)
|
||||||
grads_and_vars = opt._compute_gradients(loss, [var0])
|
grads_and_vars = optimizer._compute_gradients(loss, [var0])
|
||||||
opt.apply_gradients(grads_and_vars)
|
optimizer.apply_gradients(grads_and_vars)
|
||||||
|
|
||||||
|
|
||||||
class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
@ -248,8 +271,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
the Estimator framework.
|
the Estimator framework.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _make_linear_model_fn(self, opt_cls, l2_norm_clip, noise_multiplier,
|
def _make_linear_model_fn(self, optimizer_class, l2_norm_clip,
|
||||||
num_microbatches, learning_rate):
|
noise_multiplier, num_microbatches, learning_rate):
|
||||||
"""Returns a model function for a linear regressor."""
|
"""Returns a model function for a linear regressor."""
|
||||||
|
|
||||||
def linear_model_fn(features, labels, mode):
|
def linear_model_fn(features, labels, mode):
|
||||||
|
@ -264,7 +287,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
vector_loss = 0.5 * tf.math.squared_difference(labels, preds)
|
vector_loss = 0.5 * tf.math.squared_difference(labels, preds)
|
||||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||||
|
|
||||||
optimizer = opt_cls(
|
optimizer = optimizer_class(
|
||||||
l2_norm_clip=l2_norm_clip,
|
l2_norm_clip=l2_norm_clip,
|
||||||
noise_multiplier=noise_multiplier,
|
noise_multiplier=noise_multiplier,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
|
@ -280,26 +303,26 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
return linear_model_fn
|
return linear_model_fn
|
||||||
|
|
||||||
# Parameters for testing: optimizer, num_microbatches.
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
|
('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
|
||||||
('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
|
('DPGradientDescent_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
|
||||||
('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
|
('DPGradientDescent_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
|
||||||
('DPGradientDescentVectorized 1',
|
('DPGradientDescent_None', dp_optimizer_keras.DPKerasSGDOptimizer, None),
|
||||||
|
('DPGradientDescentVectorized_1',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
|
||||||
('DPGradientDescentVectorized 2',
|
('DPGradientDescentVectorized_2',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
|
||||||
('DPGradientDescentVectorized 4',
|
('DPGradientDescentVectorized_4',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
|
||||||
('DPGradientDescentVectorized None',
|
('DPGradientDescentVectorized_None',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
|
||||||
)
|
)
|
||||||
def testBaseline(self, cls, num_microbatches):
|
def testBaselineNoNoise(self, optimizer_class, num_microbatches):
|
||||||
"""Tests that DP optimizers work with tf.estimator."""
|
"""Tests that DP optimizers work with tf.estimator."""
|
||||||
|
|
||||||
linear_regressor = tf_estimator.Estimator(
|
linear_regressor = tf_estimator.Estimator(
|
||||||
model_fn=self._make_linear_model_fn(cls, 100.0, 0.0, num_microbatches,
|
model_fn=self._make_linear_model_fn(optimizer_class, 100.0, 0.0,
|
||||||
0.05))
|
num_microbatches, 0.05))
|
||||||
|
|
||||||
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
||||||
true_bias = np.array([6.0]).astype(np.float32)
|
true_bias = np.array([6.0]).astype(np.float32)
|
||||||
|
@ -322,13 +345,12 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05)
|
linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05)
|
||||||
|
|
||||||
# Parameters for testing: optimizer, num_microbatches.
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
|
('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer),
|
||||||
('DPGradientDescentVectorized 1',
|
('DPGradientDescentVectorized_1',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
|
||||||
)
|
)
|
||||||
def testClippingNorm(self, cls, num_microbatches):
|
def testClippingNorm(self, optimizer_class):
|
||||||
"""Tests that DP optimizers work with tf.estimator."""
|
"""Tests that DP optimizers work with tf.estimator."""
|
||||||
|
|
||||||
true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32)
|
true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32)
|
||||||
|
@ -342,8 +364,12 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
(train_data, train_labels)).batch(1)
|
(train_data, train_labels)).batch(1)
|
||||||
|
|
||||||
unclipped_linear_regressor = tf_estimator.Estimator(
|
unclipped_linear_regressor = tf_estimator.Estimator(
|
||||||
model_fn=self._make_linear_model_fn(cls, 1.0e9, 0.0, num_microbatches,
|
model_fn=self._make_linear_model_fn(
|
||||||
1.0))
|
optimizer_class=optimizer_class,
|
||||||
|
l2_norm_clip=1.0e9,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
num_microbatches=1,
|
||||||
|
learning_rate=1.0))
|
||||||
unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
||||||
|
|
||||||
kernel_value = unclipped_linear_regressor.get_variable_value('dense/kernel')
|
kernel_value = unclipped_linear_regressor.get_variable_value('dense/kernel')
|
||||||
|
@ -351,8 +377,12 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
global_norm = np.linalg.norm(np.concatenate((kernel_value, [bias_value])))
|
global_norm = np.linalg.norm(np.concatenate((kernel_value, [bias_value])))
|
||||||
|
|
||||||
clipped_linear_regressor = tf_estimator.Estimator(
|
clipped_linear_regressor = tf_estimator.Estimator(
|
||||||
model_fn=self._make_linear_model_fn(cls, 1.0, 0.0, num_microbatches,
|
model_fn=self._make_linear_model_fn(
|
||||||
1.0))
|
optimizer_class=optimizer_class,
|
||||||
|
l2_norm_clip=1.0,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
num_microbatches=1,
|
||||||
|
learning_rate=1.0))
|
||||||
clipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
clipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
|
@ -367,29 +397,29 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
# Parameters for testing: optimizer, l2_norm_clip, noise_multiplier,
|
# Parameters for testing: optimizer, l2_norm_clip, noise_multiplier,
|
||||||
# num_microbatches.
|
# num_microbatches.
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
||||||
4.0, 1),
|
4.0, 1),
|
||||||
('DPGradientDescent 3 2 4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0,
|
('DPGradientDescent_3_2_4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0,
|
||||||
2.0, 4),
|
2.0, 4),
|
||||||
('DPGradientDescent 8 6 8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0,
|
('DPGradientDescent_8_6_8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0,
|
||||||
6.0, 8),
|
6.0, 8),
|
||||||
('DPGradientDescentVectorized 2 4 1',
|
('DPGradientDescentVectorized_2_4_1',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
|
||||||
1),
|
1),
|
||||||
('DPGradientDescentVectorized 3 2 4',
|
('DPGradientDescentVectorized_3_2_4',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0,
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0,
|
||||||
4),
|
4),
|
||||||
('DPGradientDescentVectorized 8 6 8',
|
('DPGradientDescentVectorized_8_6_8',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0,
|
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0,
|
||||||
8),
|
8),
|
||||||
)
|
)
|
||||||
def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier,
|
def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier,
|
||||||
num_microbatches):
|
num_microbatches):
|
||||||
"""Tests that DP optimizers work with tf.estimator."""
|
"""Tests that DP optimizers work with tf.estimator."""
|
||||||
|
|
||||||
linear_regressor = tf_estimator.Estimator(
|
linear_regressor = tf_estimator.Estimator(
|
||||||
model_fn=self._make_linear_model_fn(
|
model_fn=self._make_linear_model_fn(
|
||||||
cls,
|
optimizer_class,
|
||||||
l2_norm_clip,
|
l2_norm_clip,
|
||||||
noise_multiplier,
|
noise_multiplier,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
|
@ -423,9 +453,9 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
('DPAdamVectorized',
|
('DPAdamVectorized',
|
||||||
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer),
|
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer),
|
||||||
)
|
)
|
||||||
def testAssertOnNoCallOfGetGradients(self, cls):
|
def testRaisesOnNoCallOfGetGradients(self, optimizer_class):
|
||||||
"""Tests that assertion fails when DP gradients are not computed."""
|
"""Tests that assertion fails when DP gradients are not computed."""
|
||||||
opt = cls(
|
optimizer = optimizer_class(
|
||||||
l2_norm_clip=100.0,
|
l2_norm_clip=100.0,
|
||||||
noise_multiplier=0.0,
|
noise_multiplier=0.0,
|
||||||
num_microbatches=1,
|
num_microbatches=1,
|
||||||
|
@ -433,7 +463,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
grads_and_vars = tf.Variable([0.0])
|
grads_and_vars = tf.Variable([0.0])
|
||||||
opt.apply_gradients(grads_and_vars)
|
optimizer.apply_gradients(grads_and_vars)
|
||||||
|
|
||||||
def testLargeBatchEmulationNoNoise(self):
|
def testLargeBatchEmulationNoNoise(self):
|
||||||
# Test for emulation of large batch training.
|
# Test for emulation of large batch training.
|
||||||
|
@ -454,7 +484,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32)
|
x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32)
|
||||||
loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1
|
loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1
|
||||||
|
|
||||||
opt = dp_optimizer_keras.DPKerasSGDOptimizer(
|
optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
|
||||||
l2_norm_clip=100.0,
|
l2_norm_clip=100.0,
|
||||||
noise_multiplier=0.0,
|
noise_multiplier=0.0,
|
||||||
gradient_accumulation_steps=2,
|
gradient_accumulation_steps=2,
|
||||||
|
@ -464,35 +494,36 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
||||||
self.assertAllCloseAccordingToType([3.0], var1)
|
self.assertAllCloseAccordingToType([3.0], var1)
|
||||||
|
|
||||||
opt.minimize(loss1, [var0, var1])
|
optimizer.minimize(loss1, [var0, var1])
|
||||||
# After first call to optimizer values didn't change
|
# After first call to optimizer values didn't change
|
||||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
||||||
self.assertAllCloseAccordingToType([3.0], var1)
|
self.assertAllCloseAccordingToType([3.0], var1)
|
||||||
|
|
||||||
opt.minimize(loss2, [var0, var1])
|
optimizer.minimize(loss2, [var0, var1])
|
||||||
# After second call to optimizer updates were applied
|
# After second call to optimizer updates were applied
|
||||||
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
||||||
self.assertAllCloseAccordingToType([2.0], var1)
|
self.assertAllCloseAccordingToType([2.0], var1)
|
||||||
|
|
||||||
opt.minimize(loss2, [var0, var1])
|
optimizer.minimize(loss2, [var0, var1])
|
||||||
# After third call to optimizer values didn't change
|
# After third call to optimizer values didn't change
|
||||||
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
||||||
self.assertAllCloseAccordingToType([2.0], var1)
|
self.assertAllCloseAccordingToType([2.0], var1)
|
||||||
|
|
||||||
opt.minimize(loss2, [var0, var1])
|
optimizer.minimize(loss2, [var0, var1])
|
||||||
# After fourth call to optimizer updates were applied again
|
# After fourth call to optimizer updates were applied again
|
||||||
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
|
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
|
||||||
self.assertAllCloseAccordingToType([1.0], var1)
|
self.assertAllCloseAccordingToType([1.0], var1)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('DPKerasSGDOptimizer 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
|
('DPKerasSGDOptimizer_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
|
||||||
('DPKerasSGDOptimizer 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
|
('DPKerasSGDOptimizer_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
|
||||||
('DPKerasSGDOptimizer 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
|
('DPKerasSGDOptimizer_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
|
||||||
('DPKerasAdamOptimizer 2', dp_optimizer_keras.DPKerasAdamOptimizer, 1),
|
('DPKerasAdamOptimizer_2', dp_optimizer_keras.DPKerasAdamOptimizer, 1),
|
||||||
('DPKerasAdagradOptimizer 2', dp_optimizer_keras.DPKerasAdagradOptimizer,
|
('DPKerasAdagradOptimizer_2', dp_optimizer_keras.DPKerasAdagradOptimizer,
|
||||||
2),
|
2),
|
||||||
)
|
)
|
||||||
def testLargeBatchEmulation(self, cls, gradient_accumulation_steps):
|
def testLargeBatchEmulation(self, optimizer_class,
|
||||||
|
gradient_accumulation_steps):
|
||||||
# Tests various optimizers with large batch emulation.
|
# Tests various optimizers with large batch emulation.
|
||||||
# Uses clipping and noise, thus does not test specific values
|
# Uses clipping and noise, thus does not test specific values
|
||||||
# of the variables and only tests how often variables are updated.
|
# of the variables and only tests how often variables are updated.
|
||||||
|
@@ -501,7 +532,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
     x = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
     loss = lambda: tf.matmul(var0, x, transpose_b=True) + var1

-    opt = cls(
+    optimizer = optimizer_class(
         l2_norm_clip=100.0,
         noise_multiplier=0.0,
         gradient_accumulation_steps=gradient_accumulation_steps,
@@ -510,7 +541,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
     for _ in range(gradient_accumulation_steps):
       self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
       self.assertAllCloseAccordingToType([3.0], var1)
-      opt.minimize(loss, [var0, var1])
+      optimizer.minimize(loss, [var0, var1])

     self.assertNotAllClose([[1.0, 2.0]], var0)
     self.assertNotAllClose([3.0], var1)
@@ -547,19 +578,19 @@ class SimpleEmbeddingModel(tf.keras.Model):
     return sequence_output, pooled_output


-def keras_embedding_model_fn(opt_cls,
+def keras_embedding_model_fn(optimizer_class,
                              l2_norm_clip: float,
                              noise_multiplier: float,
                              num_microbatches: int,
                              learning_rate: float,
-                             use_seq_output: bool = False,
+                             use_sequence_output: bool = False,
                              unconnected_gradients_to_zero: bool = False):
   """Construct a simple embedding model with a classification layer."""

   # Every sample has 4 tokens (sequence length=4).
   x = tf.keras.layers.Input(shape=(4,), dtype=tf.float32, name='input')
   sequence_output, pooled_output = SimpleEmbeddingModel()(x)
-  if use_seq_output:
+  if use_sequence_output:
     embedding = sequence_output
   else:
     embedding = pooled_output
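For orientation, the two outputs this helper can expose differ only in shape. A small illustrative sketch (the real SimpleEmbeddingModel is defined earlier in this test file and chooses its own sizes and pooling; the mean pooling and layer sizes below are assumptions for the example):

import tensorflow as tf

embedding_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=8)
tokens = tf.constant([[1, 2, 3, 4]])             # batch of 1, sequence length 4
sequence_output = embedding_layer(tokens)        # shape (1, 4, 8): one vector per token
pooled_output = tf.reduce_mean(sequence_output, axis=1)  # shape (1, 8): one vector per example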
@@ -568,7 +599,7 @@ def keras_embedding_model_fn(opt_cls,
       embedding)
   model = tf.keras.Model(inputs=x, outputs=probs, name='model')

-  optimizer = opt_cls(
+  optimizer = optimizer_class(
       l2_norm_clip=l2_norm_clip,
       noise_multiplier=noise_multiplier,
       num_microbatches=num_microbatches,
@@ -608,7 +639,7 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
   @parameterized.named_parameters(
       ('DPSGDVectorized_SeqOutput_UnconnectedGradients',
        dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
-  def testSeqOutputUnconnectedGradientsAsNoneFails(self, cls):
+  def testSeqOutputUnconnectedGradientsAsNoneFails(self, optimizer_class):
     """Tests that DP vectorized optimizers with 'None' unconnected gradients fail.

     Sequence models that have unconnected gradients (with
@@ -620,16 +651,16 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
     These tests test the various combinations of this flag and the model.

     Args:
-      cls: The DP optimizer class to test.
+      optimizer_class: The DP optimizer class to test.
     """

     embedding_model = keras_embedding_model_fn(
-        cls,
+        optimizer_class,
         l2_norm_clip=1.0,
         noise_multiplier=0.5,
         num_microbatches=1,
         learning_rate=1.0,
-        use_seq_output=True,
+        use_sequence_output=True,
         unconnected_gradients_to_zero=False)

     train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
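The flag exercised above maps onto TensorFlow's handling of unconnected gradients. A self-contained sketch of that mechanism, using illustrative variables rather than the embedding model from these tests: a variable that never reaches the loss yields a None gradient by default, and a zero gradient when tf.UnconnectedGradients.ZERO is requested, which is a form the per-example clipping can work with.

import tensorflow as tf

connected = tf.Variable(2.0)
unconnected = tf.Variable(3.0)   # never used in the loss below

with tf.GradientTape() as tape:
  loss = connected * connected
print(tape.gradient(loss, [connected, unconnected]))
# -> [4.0, None]; the None entry is what the 'None' variants trip over.

with tf.GradientTape() as tape:
  loss = connected * connected
print(tape.gradient(loss, [connected, unconnected],
                    unconnected_gradients=tf.UnconnectedGradients.ZERO))
# -> [4.0, 0.0]; zero-filled gradients can be clipped and noised normally.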
@@ -651,16 +682,17 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
   @parameterized.named_parameters(
       ('DPSGDVectorized_PooledOutput_UnconnectedGradients',
        dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
-  def testPooledOutputUnconnectedGradientsAsNonePasses(self, cls):
-    """Tests that DP vectorized optimizers with 'None' unconnected gradients fail."""
+  def testPooledOutputUnconnectedGradientsAsNonePasses(self, optimizer_class):
+    """Tests that DP vectorized optimizers with 'None' unconnected gradients fail.
+    """

     embedding_model = keras_embedding_model_fn(
-        cls,
+        optimizer_class,
         l2_norm_clip=1.0,
         noise_multiplier=0.5,
         num_microbatches=1,
         learning_rate=1.0,
-        use_seq_output=False,
+        use_sequence_output=False,
         unconnected_gradients_to_zero=False)

     train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
@@ -684,16 +716,18 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
       ('DPSGDVectorized_PooledOutput_UnconnectedGradientsAreZero',
        dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False),
   )
-  def testUnconnectedGradientsAsZeroPasses(self, cls, use_seq_output):
-    """Tests that DP vectorized optimizers with 'Zero' unconnected gradients pass."""
+  def testUnconnectedGradientsAsZeroPasses(self, optimizer_class,
+                                           use_sequence_output):
+    """Tests that DP vectorized optimizers with 'Zero' unconnected gradients pass.
+    """

     embedding_model = keras_embedding_model_fn(
-        cls,
+        optimizer_class,
         l2_norm_clip=1.0,
         noise_multiplier=0.5,
         num_microbatches=1,
         learning_rate=1.0,
-        use_seq_output=use_seq_output,
+        use_sequence_output=use_sequence_output,
         unconnected_gradients_to_zero=True)

     train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
@@ -710,5 +744,6 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
       # other exceptions are errors.
       self.fail('ValueError raised by model.fit().')

+
 if __name__ == '__main__':
   tf.test.main()