Changes DPOptimizerClass to generically accept and use any dp_sum_query.

This enables the creation of generic DPOptimizers by letting users pass in their own queries. The most common query, the Gaussian query, is still applied automatically for convenience and backwards compatibility.

Byproducts of this update:
-ensures consistent implementations between the internal (and legacy) `get_gradients` and newer `_compute_gradients` for all queries.
-refactors for python readability.
-includes new tests ensuring that `_num_microbatches=None` is tested.
-changes the `_global_state` to be initialized in the init function for `_compute_gradients`.

PiperOrigin-RevId: 480668376
This commit is contained in:
A. Unique TensorFlower 2022-10-12 11:03:22 -07:00
parent f8ed0fcd9c
commit 79fe32a60b
4 changed files with 388 additions and 215 deletions

View file

@ -61,9 +61,14 @@ else:
from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class
# Optimizers # Optimizers
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_gaussian_query_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_generic_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasAdagradOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasAdagradOptimizer

View file

@ -18,6 +18,18 @@ py_library(
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"], deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
) )
py_library(
name = "dp_optimizer_factory",
srcs = [
"dp_optimizer_keras.py",
],
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/dp_query",
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
],
)
py_library( py_library(
name = "dp_optimizer_vectorized", name = "dp_optimizer_vectorized",
srcs = [ srcs = [
@ -32,7 +44,10 @@ py_library(
"dp_optimizer_keras.py", "dp_optimizer_keras.py",
], ],
srcs_version = "PY3", srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"], deps = [
"//tensorflow_privacy/privacy/dp_query",
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
],
) )
py_library( py_library(
@ -84,7 +99,7 @@ py_test(
python_version = "PY3", python_version = "PY3",
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", ":dp_optimizer_keras",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras_vectorized", ":dp_optimizer_keras_vectorized",
], ],
) )

View file

@ -13,29 +13,40 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Differentially private version of Keras optimizer v2.""" """Differentially private version of Keras optimizer v2."""
from typing import Optional, Type
import warnings
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query from tensorflow_privacy.privacy.dp_query import gaussian_query
def make_keras_optimizer_class(cls): def _normalize(microbatch_gradient: tf.Tensor,
"""Given a subclass of `tf.keras.optimizers.legacy.Optimizer`, returns a DP-SGD subclass of it. num_microbatches: float) -> tf.Tensor:
"""Normalizes `microbatch_gradient` by `num_microbatches`."""
return tf.truediv(microbatch_gradient,
tf.cast(num_microbatches, microbatch_gradient.dtype))
def make_keras_generic_optimizer_class(
cls: Type[tf.keras.optimizers.Optimizer]):
"""Returns a differentially private (DP) subclass of `cls`.
Args: Args:
cls: Class from which to derive a DP subclass. Should be a subclass of cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.legacy.Optimizer`. `tf.keras.optimizers.legacy.Optimizer`.
Returns: Returns:
A DP-SGD subclass of `cls`. A generic DP-SGD subclass of `cls`, compatible with many DP queries.
""" """
class DPOptimizerClass(cls): # pylint: disable=empty-docstring class DPOptimizerClass(cls): # pylint: disable=empty-docstring,missing-class-docstring
__doc__ = """Differentially private subclass of class `{base_class}`. __doc__ = """Differentially private subclass of class `{base_class}`.
You can use this as a differentially private replacement for You can use this as a differentially private replacement for
`{base_class}`. This optimizer implements DP-SGD using `{base_class}`. This optimizer implements a differentiallyy private version
the standard Gaussian mechanism. of the stochastic gradient descent optimizer `cls` using the chosen
`dp_query.DPQuery` instance.
When instantiating this optimizer, you need to supply several When instantiating this optimizer, you need to supply several
DP-related arguments followed by the standard arguments for DP-related arguments followed by the standard arguments for
@ -45,8 +56,10 @@ def make_keras_optimizer_class(cls):
```python ```python
# Create optimizer. # Create optimizer.
opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1, gaussian_query = gaussian_query.GaussianSumQuery(
<standard arguments>) l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1
)
opt = {dp_keras_class}(dp_sum_query=gaussian_query, <standard arguments>)
``` ```
When using the optimizer, be sure to pass in the loss as a When using the optimizer, be sure to pass in the loss as a
@ -92,8 +105,10 @@ def make_keras_optimizer_class(cls):
```python ```python
# Create optimizer which will be accumulating gradients for 4 steps. # Create optimizer which will be accumulating gradients for 4 steps.
# and then performing an update of model weights. # and then performing an update of model weights.
opt = {dp_keras_class}(l2_norm_clip=1.0, gaussian_query = gaussian_query.GaussianSumQuery(
noise_multiplier=0.5, l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1
)
opt = {dp_keras_class}(dp_sum_query=gaussian_query,
num_microbatches=1, num_microbatches=1,
gradient_accumulation_steps=4, gradient_accumulation_steps=4,
<standard arguments>) <standard arguments>)
@ -138,24 +153,23 @@ def make_keras_optimizer_class(cls):
def __init__( def __init__(
self, self,
l2_norm_clip, dp_sum_query: dp_query.DPQuery,
noise_multiplier, num_microbatches: Optional[int] = None,
num_microbatches=None, gradient_accumulation_steps: int = 1,
gradient_accumulation_steps=1,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs): **kwargs):
"""Initialize the DPOptimizerClass. """Initializes the DPOptimizerClass.
Args: Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients). dp_sum_query: `DPQuery` object, specifying differential privacy
noise_multiplier: Ratio of the standard deviation to the clipping norm. mechanism to use.
num_microbatches: Number of microbatches into which each minibatch is num_microbatches: Number of microbatches into which each minibatch is
split. Default is `None` which means that number of microbatches split. Default is `None` which means that number of microbatches is
is equal to batch size (i.e. each microbatch contains exactly one equal to batch size (i.e. each microbatch contains exactly one
example). If `gradient_accumulation_steps` is greater than 1 and example). If `gradient_accumulation_steps` is greater than 1 and
`num_microbatches` is not `None` then the effective number of `num_microbatches` is not `None` then the effective number of
microbatches is equal to microbatches is equal to `num_microbatches *
`num_microbatches * gradient_accumulation_steps`. gradient_accumulation_steps`.
gradient_accumulation_steps: If greater than 1 then optimizer will be gradient_accumulation_steps: If greater than 1 then optimizer will be
accumulating gradients for this number of optimizer steps before accumulating gradients for this number of optimizer steps before
applying them to update model weights. If this argument is set to 1 applying them to update model weights. If this argument is set to 1
@ -165,13 +179,16 @@ def make_keras_optimizer_class(cls):
""" """
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.gradient_accumulation_steps = gradient_accumulation_steps self.gradient_accumulation_steps = gradient_accumulation_steps
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._num_microbatches = num_microbatches self._num_microbatches = num_microbatches
self._dp_sum_query = gaussian_query.GaussianSumQuery( self._dp_sum_query = dp_sum_query
l2_norm_clip, l2_norm_clip * noise_multiplier)
self._global_state = None
self._was_dp_gradients_called = False self._was_dp_gradients_called = False
# We initialize here for `_compute_gradients` because of requirements from
# the tf.keras.Model API. Specifically, keras models use the
# `_compute_gradients` method for both eager and graph mode. So,
# instantiating the state here is necessary to avoid graph compilation
# issues.
self._global_state = self._dp_sum_query.initial_global_state()
def _create_slots(self, var_list): def _create_slots(self, var_list):
super()._create_slots(var_list) # pytype: disable=attribute-error super()._create_slots(var_list) # pytype: disable=attribute-error
@ -233,19 +250,21 @@ def make_keras_optimizer_class(cls):
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""DP-SGD version of base class method.""" """DP-SGD version of base class method."""
self._was_dp_gradients_called = True self._was_dp_gradients_called = True
# Compute loss. # Compute loss.
if not callable(loss) and tape is None: if not callable(loss) and tape is None:
raise ValueError('`tape` is required when a `Tensor` loss is passed.') raise ValueError('`tape` is required when a `Tensor` loss is passed.')
tape = tape if tape is not None else tf.GradientTape() tape = tape if tape is not None else tf.GradientTape()
if callable(loss):
with tape: with tape:
if callable(loss):
if not callable(var_list): if not callable(var_list):
tape.watch(var_list) tape.watch(var_list)
loss = loss() loss = loss()
if self._num_microbatches is None: if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0] num_microbatches = tf.shape(input=loss)[0]
else: else:
@ -255,60 +274,64 @@ def make_keras_optimizer_class(cls):
if callable(var_list): if callable(var_list):
var_list = var_list() var_list = var_list()
else:
with tape:
if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [num_microbatches, -1]), axis=1)
var_list = tf.nest.flatten(var_list) var_list = tf.nest.flatten(var_list)
sample_params = (
self._dp_sum_query.derive_sample_params(self._global_state))
# Compute the per-microbatch losses using helpful jacobian method. # Compute the per-microbatch losses using helpful jacobian method.
with tf.keras.backend.name_scope(self._name + '/gradients'): with tf.keras.backend.name_scope(self._name + '/gradients'):
jacobian = tape.jacobian( jacobian_per_var = tape.jacobian(
microbatch_losses, var_list, unconnected_gradients='zero') microbatch_losses, var_list, unconnected_gradients='zero')
def clip_gradients(g): def process_microbatch(sample_state, microbatch_jacobians):
"""Clips gradients to given l2_norm_clip.""" """Process one microbatch (record) with privacy helper."""
return tf.clip_by_global_norm(g, self._l2_norm_clip)[0] sample_state = self._dp_sum_query.accumulate_record(
sample_params, sample_state, microbatch_jacobians)
return sample_state
# Clip all gradients. Note that `tf.map_fn` applies the given function sample_state = self._dp_sum_query.initial_sample_state(var_list)
# to its arguments unstacked along axis 0.
clipped_gradients = tf.map_fn(clip_gradients, jacobian)
def reduce_noise_normalize_batch(g): def body_fn(idx, sample_state):
# Sum gradients over all microbatches. microbatch_jacobians_per_var = [
summed_gradient = tf.reduce_sum(g, axis=0) jacobian[idx] for jacobian in jacobian_per_var
]
sample_state = process_microbatch(sample_state,
microbatch_jacobians_per_var)
return tf.add(idx, 1), sample_state
# Add noise to summed gradients. cond_fn = lambda idx, _: tf.less(idx, num_microbatches)
noise_stddev = self._l2_norm_clip * self._noise_multiplier idx = tf.constant(0)
noise = tf.random.normal( _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
tf.shape(input=summed_gradient), stddev=noise_stddev)
noised_gradient = tf.add(summed_gradient, noise)
# Normalize by number of microbatches and return. grad_sums, self._global_state, _ = (
return tf.truediv(noised_gradient, self._dp_sum_query.get_noised_result(sample_state,
tf.cast(num_microbatches, tf.float32)) self._global_state))
final_grads = tf.nest.map_structure(_normalize, grad_sums,
[num_microbatches] * len(grad_sums))
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, return list(zip(final_grads, var_list))
clipped_gradients)
return list(zip(final_gradients, var_list))
def get_gradients(self, loss, params): def get_gradients(self, loss, params):
"""DP-SGD version of base class method.""" """DP-SGD version of base class method."""
if not self._was_dp_gradients_called:
self._was_dp_gradients_called = True # We create the global state here due to tf.Estimator API requirements,
if self._global_state is None: # specifically, that instantiating the global state outside this
# function leads to graph compilation errors of attempting to capture an
# EagerTensor.
self._global_state = self._dp_sum_query.initial_global_state() self._global_state = self._dp_sum_query.initial_global_state()
self._was_dp_gradients_called = True
# This code mostly follows the logic in the original DPOptimizerClass # This code mostly follows the logic in the original DPOptimizerClass
# in dp_optimizer.py, except that this returns only the gradients, # in dp_optimizer.py, except that this returns only the gradients,
# not the gradients and variables. # not the gradients and variables.
microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1]) if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reshape(loss, [num_microbatches, -1])
sample_params = ( sample_params = (
self._dp_sum_query.derive_sample_params(self._global_state)) self._dp_sum_query.derive_sample_params(self._global_state))
@ -322,19 +345,20 @@ def make_keras_optimizer_class(cls):
return sample_state return sample_state
sample_state = self._dp_sum_query.initial_sample_state(params) sample_state = self._dp_sum_query.initial_sample_state(params)
for idx in range(self._num_microbatches):
def body_fn(idx, sample_state):
sample_state = process_microbatch(idx, sample_state) sample_state = process_microbatch(idx, sample_state)
return tf.add(idx, 1), sample_state
cond_fn = lambda idx, _: tf.less(idx, num_microbatches)
idx = tf.constant(0)
_, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
grad_sums, self._global_state, _ = ( grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state, self._dp_sum_query.get_noised_result(sample_state,
self._global_state)) self._global_state))
def normalize(v): final_grads = tf.nest.map_structure(_normalize, grad_sums,
try: [num_microbatches] * len(grad_sums))
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
except TypeError:
return None
final_grads = tf.nest.map_structure(normalize, grad_sums)
return final_grads return final_grads
@ -351,8 +375,7 @@ def make_keras_optimizer_class(cls):
""" """
config = super().get_config() config = super().get_config()
config.update({ config.update({
'l2_norm_clip': self._l2_norm_clip, 'global_state': self._global_state._asdict(),
'noise_multiplier': self._noise_multiplier,
'num_microbatches': self._num_microbatches, 'num_microbatches': self._num_microbatches,
}) })
return config return config
@ -370,8 +393,103 @@ def make_keras_optimizer_class(cls):
return DPOptimizerClass return DPOptimizerClass
DPKerasAdagradOptimizer = make_keras_optimizer_class( def make_gaussian_query_optimizer_class(cls):
"""Returns a differentially private optimizer using the `GaussianSumQuery`.
Args:
cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
Returns:
A DP-SGD subclass of `cls` using the `GaussianQuery`, the canonical DP-SGD
implementation.
"""
def return_gaussian_query_optimizer(
l2_norm_clip: float,
noise_multiplier: float,
num_microbatches: Optional[int] = None,
gradient_accumulation_steps: int = 1,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.
This function is a thin wrapper around
`make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
apply a `GaussianSumQuery` to any `DPOptimizerClass`.
When combined with stochastic gradient descent, this creates the canonical
DP-SGD algorithm of "Deep Learning with Differential Privacy"
(see https://arxiv.org/abs/1607.00133).
When instantiating this optimizer, you need to supply several
DP-related arguments followed by the standard arguments for
`{short_base_class}`.
As an example, see the below or the documentation of the DPOptimizerClass.
```python
# Create optimizer.
opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5,
num_microbatches=1, <standard arguments>)
```
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch is
split. Default is `None` which means that number of microbatches is
equal to batch size (i.e. each microbatch contains exactly one example).
If `gradient_accumulation_steps` is greater than 1 and
`num_microbatches` is not `None` then the effective number of
microbatches is equal to `num_microbatches *
gradient_accumulation_steps`.
gradient_accumulation_steps: If greater than 1 then optimizer will be
accumulating gradients for this number of optimizer steps before
applying them to update model weights. If this argument is set to 1 then
updates will be applied on each optimizer step.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
return cls(
dp_sum_query=dp_sum_query,
num_microbatches=num_microbatches,
gradient_accumulation_steps=gradient_accumulation_steps,
*args,
**kwargs)
return return_gaussian_query_optimizer
def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
"""Returns a differentially private optimizer using the `GaussianSumQuery`.
For backwards compatibility, we create this symbol to match the previous
output of `make_keras_optimizer_class` but using the new logic.
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.Optimizer`.
"""
warnings.warn(
'`make_keras_optimizer_class` will be depracated on 2023-02-23. '
'Please switch to `make_gaussian_query_optimizer_class` and the '
'generic optimizers (`make_keras_generic_optimizer_class`).')
return make_gaussian_query_optimizer_class(
make_keras_generic_optimizer_class(cls))
GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
tf.keras.optimizers.legacy.Adagrad) tf.keras.optimizers.legacy.Adagrad)
DPKerasAdamOptimizer = make_keras_optimizer_class( GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
tf.keras.optimizers.legacy.Adam) tf.keras.optimizers.legacy.Adam)
DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.legacy.SGD) GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
tf.keras.optimizers.legacy.SGD)
# We keep the same names for backwards compatibility.
DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
GenericDPAdagradOptimizer)
DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
GenericDPAdamOptimizer)
DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)

View file

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -29,36 +28,30 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
return 0.5 * tf.reduce_sum( return 0.5 * tf.reduce_sum(
input_tensor=tf.math.squared_difference(val0, val1), axis=1) input_tensor=tf.math.squared_difference(val0, val1), axis=1)
# Parameters for testing: optimizer, num_microbatches, expected gradient for
# var0, expected gradient for var1.
@parameterized.named_parameters( @parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1, ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
[-2.5, -2.5], [-0.5]), ('DPGradientDescent_None', dp_optimizer_keras.DPKerasSGDOptimizer, None),
('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, [-2.5, -2.5 ('DPAdam_2', dp_optimizer_keras.DPKerasAdamOptimizer, 2),
], [-0.5]), ('DPAdagrad _4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4),
('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4, ('DPGradientDescentVectorized_1',
[-2.5, -2.5], [-0.5]), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
('DPGradientDescentVectorized 1', ('DPAdamVectorized_2',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2),
[-2.5, -2.5], [-0.5]), ('DPAdagradVectorized_4',
('DPAdamVectorized 2', dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4),
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2, ('DPAdagradVectorized_None',
[-2.5, -2.5], [-0.5]), dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None),
('DPAdagradVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
[-2.5, -2.5], [-0.5]),
('DPAdagradVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
[-2.5, -2.5], [-0.5]),
) )
def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0, def testBaselineWithCallableLossNoNoise(self, optimizer_class,
expected_grad1): num_microbatches):
var0 = tf.Variable([1.0, 2.0]) var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0]) var1 = tf.Variable([3.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]]) data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
expected_grad0 = [-2.5, -2.5]
expected_grad1 = [-0.5]
opt = cls( optimizer = optimizer_class(
l2_norm_clip=100.0, l2_norm_clip=100.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
@ -66,40 +59,68 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
loss = lambda: self._loss(data0, var0) + self._loss(data1, var1) loss = lambda: self._loss(data0, var0) + self._loss(data1, var1)
grads_and_vars = opt._compute_gradients(loss, [var0, var1]) grads_and_vars = optimizer._compute_gradients(loss, [var0, var1])
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0]) self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0]) self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
# Parameters for testing: optimizer, num_microbatches, expected gradient for def testKerasModelBaselineNoNoiseNoneMicrobatches(self):
# var0, expected gradient for var1. """Tests that DP optimizers work with tf.keras.Model."""
model = tf.keras.models.Sequential(layers=[
tf.keras.layers.Dense(
1,
activation='linear',
name='dense',
kernel_initializer='zeros',
bias_initializer='zeros')
])
optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
l2_norm_clip=100.0,
noise_multiplier=0.0,
num_microbatches=None,
learning_rate=0.05)
loss = tf.keras.losses.MeanSquaredError(reduction='none')
model.compile(optimizer, loss)
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = np.array([6.0]).astype(np.float32)
train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.0, size=(1000, 1)).astype(np.float32)
model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)
self.assertAllClose(model.get_weights()[0], true_weights, atol=0.05)
self.assertAllClose(model.get_weights()[1], true_bias, atol=0.05)
@parameterized.named_parameters( @parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1, ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
[-2.5, -2.5], [-0.5]), ('DPGradientDescent_None', dp_optimizer_keras.DPKerasSGDOptimizer, None),
('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, [-2.5, -2.5 ('DPAdam_2', dp_optimizer_keras.DPKerasAdamOptimizer, 2),
], [-0.5]), ('DPAdagrad_4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4),
('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4, ('DPGradientDescentVectorized_1',
[-2.5, -2.5], [-0.5]), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
('DPGradientDescentVectorized 1', ('DPAdamVectorized_2',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2),
[-2.5, -2.5], [-0.5]), ('DPAdagradVectorized_4',
('DPAdamVectorized 2', dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4),
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2, ('DPAdagradVectorized_None',
[-2.5, -2.5], [-0.5]), dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None),
('DPAdagradVectorized 4',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4,
[-2.5, -2.5], [-0.5]),
('DPAdagradVectorized None',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None,
[-2.5, -2.5], [-0.5]),
) )
def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0, def testBaselineWithTensorLossNoNoise(self, optimizer_class,
expected_grad1): num_microbatches):
var0 = tf.Variable([1.0, 2.0]) var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0]) var1 = tf.Variable([3.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]]) data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
expected_grad0 = [-2.5, -2.5]
expected_grad1 = [-0.5]
opt = cls( optimizer = optimizer_class(
l2_norm_clip=100.0, l2_norm_clip=100.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
@ -109,7 +130,7 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
with tape: with tape:
loss = self._loss(data0, var0) + self._loss(data1, var1) loss = self._loss(data0, var0) + self._loss(data1, var1)
grads_and_vars = opt._compute_gradients(loss, [var0, var1], tape=tape) grads_and_vars = optimizer._compute_gradients(loss, [var0, var1], tape=tape)
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0]) self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0]) self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
@ -118,11 +139,11 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
('DPGradientDescentVectorized', ('DPGradientDescentVectorized',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
) )
def testClippingNorm(self, cls): def testClippingNorm(self, optimizer_class):
var0 = tf.Variable([0.0, 0.0]) var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
opt = cls( optimizer = optimizer_class(
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=1, num_microbatches=1,
@ -130,7 +151,7 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
loss = lambda: self._loss(data0, var0) loss = lambda: self._loss(data0, var0)
# Expected gradient is sum of differences. # Expected gradient is sum of differences.
grads_and_vars = opt._compute_gradients(loss, [var0]) grads_and_vars = optimizer._compute_gradients(loss, [var0])
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
@parameterized.named_parameters( @parameterized.named_parameters(
@ -180,33 +201,35 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType(expected1, grads_and_vars[1][0]) self.assertAllCloseAccordingToType(expected1, grads_and_vars[1][0])
@parameterized.named_parameters(
    ('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
     4.0, 1),
    ('DPGradientDescent_4_1_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0,
     1.0, 4),
    ('DPGradientDescentVectorized_2_4_1',
     dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
     1),
    ('DPGradientDescentVectorized_4_1_4',
     dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4.0, 1.0,
     4),
)
def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier,
                        num_microbatches):
  """Checks the spread of the noise added by `_compute_gradients`.

  Builds an all-zero variable of 1000 elements and an all-zero batch of 16
  examples, computes the gradient once, and asserts that the empirical
  standard deviation of the result is within 0.5 of
  `l2_norm_clip * noise_multiplier / num_microbatches`.

  Args:
    optimizer_class: The DP optimizer class to test.
    l2_norm_clip: Clipping norm passed to the optimizer.
    noise_multiplier: Noise multiplier passed to the optimizer.
    num_microbatches: Number of microbatches passed to the optimizer.
  """
  # Fix the global seed before any other TF op is created.
  tf.random.set_seed(2)

  weights = tf.Variable(tf.zeros([1000], dtype=tf.float32))
  batch = tf.Variable(tf.zeros([16, 1000], dtype=tf.float32))

  optimizer = optimizer_class(
      l2_norm_clip=l2_norm_clip,
      noise_multiplier=noise_multiplier,
      num_microbatches=num_microbatches,
      learning_rate=2.0)

  def loss():
    return self._loss(batch, weights)

  grads_and_vars = optimizer._compute_gradients(loss, [weights])
  first_pair = grads_and_vars[0]
  noisy_gradient = first_pair[0].numpy()

  # The expected standard deviation is the clip norm times the noise
  # multiplier, divided by the number of microbatches.
  expected_stddev = l2_norm_clip * noise_multiplier / num_microbatches
  self.assertNear(np.std(noisy_gradient), expected_stddev, 0.5)
@ -221,9 +244,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
('DPAdamVectorized', ('DPAdamVectorized',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer), dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer),
) )
def testAssertOnNoCallOfComputeGradients(self, cls): def testRaisesOnNoCallOfComputeGradients(self, optimizer_class):
"""Tests that assertion fails when DP gradients are not computed.""" """Tests that assertion fails when DP gradients are not computed."""
opt = cls( optimizer = optimizer_class(
l2_norm_clip=100.0, l2_norm_clip=100.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=1, num_microbatches=1,
@ -231,14 +254,14 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
grads_and_vars = tf.Variable([0.0]) grads_and_vars = tf.Variable([0.0])
opt.apply_gradients(grads_and_vars) optimizer.apply_gradients(grads_and_vars)
# Expect no exception if _compute_gradients is called. # Expect no exception if _compute_gradients is called.
var0 = tf.Variable([0.0]) var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]]) data0 = tf.Variable([[0.0]])
loss = lambda: self._loss(data0, var0) loss = lambda: self._loss(data0, var0)
grads_and_vars = opt._compute_gradients(loss, [var0]) grads_and_vars = optimizer._compute_gradients(loss, [var0])
opt.apply_gradients(grads_and_vars) optimizer.apply_gradients(grads_and_vars)
class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
@ -248,8 +271,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
the Estimator framework. the Estimator framework.
""" """
def _make_linear_model_fn(self, opt_cls, l2_norm_clip, noise_multiplier, def _make_linear_model_fn(self, optimizer_class, l2_norm_clip,
num_microbatches, learning_rate): noise_multiplier, num_microbatches, learning_rate):
"""Returns a model function for a linear regressor.""" """Returns a model function for a linear regressor."""
def linear_model_fn(features, labels, mode): def linear_model_fn(features, labels, mode):
@ -264,7 +287,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
vector_loss = 0.5 * tf.math.squared_difference(labels, preds) vector_loss = 0.5 * tf.math.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(input_tensor=vector_loss) scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
optimizer = opt_cls( optimizer = optimizer_class(
l2_norm_clip=l2_norm_clip, l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier, noise_multiplier=noise_multiplier,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
@ -280,26 +303,26 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
return linear_model_fn return linear_model_fn
# Parameters for testing: optimizer, num_microbatches.
@parameterized.named_parameters( @parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), ('DPGradientDescent_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4), ('DPGradientDescent_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
('DPGradientDescentVectorized 1', ('DPGradientDescent_None', dp_optimizer_keras.DPKerasSGDOptimizer, None),
('DPGradientDescentVectorized_1',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1),
('DPGradientDescentVectorized 2', ('DPGradientDescentVectorized_2',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2),
('DPGradientDescentVectorized 4', ('DPGradientDescentVectorized_4',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4),
('DPGradientDescentVectorized None', ('DPGradientDescentVectorized_None',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None),
) )
def testBaseline(self, cls, num_microbatches): def testBaselineNoNoise(self, optimizer_class, num_microbatches):
"""Tests that DP optimizers work with tf.estimator.""" """Tests that DP optimizers work with tf.estimator."""
linear_regressor = tf_estimator.Estimator( linear_regressor = tf_estimator.Estimator(
model_fn=self._make_linear_model_fn(cls, 100.0, 0.0, num_microbatches, model_fn=self._make_linear_model_fn(optimizer_class, 100.0, 0.0,
0.05)) num_microbatches, 0.05))
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32) true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = np.array([6.0]).astype(np.float32) true_bias = np.array([6.0]).astype(np.float32)
@ -322,13 +345,12 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose( self.assertAllClose(
linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05) linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05)
# Parameters for testing: optimizer, num_microbatches.
@parameterized.named_parameters( @parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer),
('DPGradientDescentVectorized 1', ('DPGradientDescentVectorized_1',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),
) )
def testClippingNorm(self, cls, num_microbatches): def testClippingNorm(self, optimizer_class):
"""Tests that DP optimizers work with tf.estimator.""" """Tests that DP optimizers work with tf.estimator."""
true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32) true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32)
@ -342,8 +364,12 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
(train_data, train_labels)).batch(1) (train_data, train_labels)).batch(1)
unclipped_linear_regressor = tf_estimator.Estimator( unclipped_linear_regressor = tf_estimator.Estimator(
model_fn=self._make_linear_model_fn(cls, 1.0e9, 0.0, num_microbatches, model_fn=self._make_linear_model_fn(
1.0)) optimizer_class=optimizer_class,
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=1,
learning_rate=1.0))
unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1) unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
kernel_value = unclipped_linear_regressor.get_variable_value('dense/kernel') kernel_value = unclipped_linear_regressor.get_variable_value('dense/kernel')
@ -351,8 +377,12 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
global_norm = np.linalg.norm(np.concatenate((kernel_value, [bias_value]))) global_norm = np.linalg.norm(np.concatenate((kernel_value, [bias_value])))
clipped_linear_regressor = tf_estimator.Estimator( clipped_linear_regressor = tf_estimator.Estimator(
model_fn=self._make_linear_model_fn(cls, 1.0, 0.0, num_microbatches, model_fn=self._make_linear_model_fn(
1.0)) optimizer_class=optimizer_class,
l2_norm_clip=1.0,
noise_multiplier=0.0,
num_microbatches=1,
learning_rate=1.0))
clipped_linear_regressor.train(input_fn=train_input_fn, steps=1) clipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
self.assertAllClose( self.assertAllClose(
@ -367,29 +397,29 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
# Parameters for testing: optimizer, l2_norm_clip, noise_multiplier, # Parameters for testing: optimizer, l2_norm_clip, noise_multiplier,
# num_microbatches. # num_microbatches.
@parameterized.named_parameters( @parameterized.named_parameters(
('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0, ('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
4.0, 1), 4.0, 1),
('DPGradientDescent 3 2 4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0, ('DPGradientDescent_3_2_4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0,
2.0, 4), 2.0, 4),
('DPGradientDescent 8 6 8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0, ('DPGradientDescent_8_6_8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0,
6.0, 8), 6.0, 8),
('DPGradientDescentVectorized 2 4 1', ('DPGradientDescentVectorized_2_4_1',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0, dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0,
1), 1),
('DPGradientDescentVectorized 3 2 4', ('DPGradientDescentVectorized_3_2_4',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0, dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0,
4), 4),
('DPGradientDescentVectorized 8 6 8', ('DPGradientDescentVectorized_8_6_8',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0, dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0,
8), 8),
) )
def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier, def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier,
num_microbatches): num_microbatches):
"""Tests that DP optimizers work with tf.estimator.""" """Tests that DP optimizers work with tf.estimator."""
linear_regressor = tf_estimator.Estimator( linear_regressor = tf_estimator.Estimator(
model_fn=self._make_linear_model_fn( model_fn=self._make_linear_model_fn(
cls, optimizer_class,
l2_norm_clip, l2_norm_clip,
noise_multiplier, noise_multiplier,
num_microbatches, num_microbatches,
@ -423,9 +453,9 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
('DPAdamVectorized', ('DPAdamVectorized',
dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer), dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer),
) )
def testAssertOnNoCallOfGetGradients(self, cls): def testRaisesOnNoCallOfGetGradients(self, optimizer_class):
"""Tests that assertion fails when DP gradients are not computed.""" """Tests that assertion fails when DP gradients are not computed."""
opt = cls( optimizer = optimizer_class(
l2_norm_clip=100.0, l2_norm_clip=100.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=1, num_microbatches=1,
@ -433,7 +463,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
grads_and_vars = tf.Variable([0.0]) grads_and_vars = tf.Variable([0.0])
opt.apply_gradients(grads_and_vars) optimizer.apply_gradients(grads_and_vars)
def testLargeBatchEmulationNoNoise(self): def testLargeBatchEmulationNoNoise(self):
# Test for emulation of large batch training. # Test for emulation of large batch training.
@ -454,7 +484,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32) x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32)
loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1 loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1
opt = dp_optimizer_keras.DPKerasSGDOptimizer( optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
l2_norm_clip=100.0, l2_norm_clip=100.0,
noise_multiplier=0.0, noise_multiplier=0.0,
gradient_accumulation_steps=2, gradient_accumulation_steps=2,
@ -464,35 +494,36 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0) self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
self.assertAllCloseAccordingToType([3.0], var1) self.assertAllCloseAccordingToType([3.0], var1)
opt.minimize(loss1, [var0, var1]) optimizer.minimize(loss1, [var0, var1])
# After first call to optimizer values didn't change # After first call to optimizer values didn't change
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0) self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
self.assertAllCloseAccordingToType([3.0], var1) self.assertAllCloseAccordingToType([3.0], var1)
opt.minimize(loss2, [var0, var1]) optimizer.minimize(loss2, [var0, var1])
# After second call to optimizer updates were applied # After second call to optimizer updates were applied
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0) self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
self.assertAllCloseAccordingToType([2.0], var1) self.assertAllCloseAccordingToType([2.0], var1)
opt.minimize(loss2, [var0, var1]) optimizer.minimize(loss2, [var0, var1])
# After third call to optimizer values didn't change # After third call to optimizer values didn't change
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0) self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
self.assertAllCloseAccordingToType([2.0], var1) self.assertAllCloseAccordingToType([2.0], var1)
opt.minimize(loss2, [var0, var1]) optimizer.minimize(loss2, [var0, var1])
# After fourth call to optimizer updates were applied again # After fourth call to optimizer updates were applied again
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0) self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
self.assertAllCloseAccordingToType([1.0], var1) self.assertAllCloseAccordingToType([1.0], var1)
@parameterized.named_parameters( @parameterized.named_parameters(
('DPKerasSGDOptimizer 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), ('DPKerasSGDOptimizer_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
('DPKerasSGDOptimizer 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), ('DPKerasSGDOptimizer_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
('DPKerasSGDOptimizer 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4), ('DPKerasSGDOptimizer_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
('DPKerasAdamOptimizer 2', dp_optimizer_keras.DPKerasAdamOptimizer, 1), ('DPKerasAdamOptimizer_2', dp_optimizer_keras.DPKerasAdamOptimizer, 1),
('DPKerasAdagradOptimizer 2', dp_optimizer_keras.DPKerasAdagradOptimizer, ('DPKerasAdagradOptimizer_2', dp_optimizer_keras.DPKerasAdagradOptimizer,
2), 2),
) )
def testLargeBatchEmulation(self, cls, gradient_accumulation_steps): def testLargeBatchEmulation(self, optimizer_class,
gradient_accumulation_steps):
# Tests various optimizers with large batch emulation. # Tests various optimizers with large batch emulation.
# Uses clipping and noise, thus does not test specific values # Uses clipping and noise, thus does not test specific values
# of the variables and only tests how often variables are updated. # of the variables and only tests how often variables are updated.
@ -501,7 +532,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
x = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32) x = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
loss = lambda: tf.matmul(var0, x, transpose_b=True) + var1 loss = lambda: tf.matmul(var0, x, transpose_b=True) + var1
opt = cls( optimizer = optimizer_class(
l2_norm_clip=100.0, l2_norm_clip=100.0,
noise_multiplier=0.0, noise_multiplier=0.0,
gradient_accumulation_steps=gradient_accumulation_steps, gradient_accumulation_steps=gradient_accumulation_steps,
@ -510,7 +541,7 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
for _ in range(gradient_accumulation_steps): for _ in range(gradient_accumulation_steps):
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0) self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
self.assertAllCloseAccordingToType([3.0], var1) self.assertAllCloseAccordingToType([3.0], var1)
opt.minimize(loss, [var0, var1]) optimizer.minimize(loss, [var0, var1])
self.assertNotAllClose([[1.0, 2.0]], var0) self.assertNotAllClose([[1.0, 2.0]], var0)
self.assertNotAllClose([3.0], var1) self.assertNotAllClose([3.0], var1)
@ -547,19 +578,19 @@ class SimpleEmbeddingModel(tf.keras.Model):
return sequence_output, pooled_output return sequence_output, pooled_output
def keras_embedding_model_fn(opt_cls, def keras_embedding_model_fn(optimizer_class,
l2_norm_clip: float, l2_norm_clip: float,
noise_multiplier: float, noise_multiplier: float,
num_microbatches: int, num_microbatches: int,
learning_rate: float, learning_rate: float,
use_seq_output: bool = False, use_sequence_output: bool = False,
unconnected_gradients_to_zero: bool = False): unconnected_gradients_to_zero: bool = False):
"""Construct a simple embedding model with a classification layer.""" """Construct a simple embedding model with a classification layer."""
# Every sample has 4 tokens (sequence length=4). # Every sample has 4 tokens (sequence length=4).
x = tf.keras.layers.Input(shape=(4,), dtype=tf.float32, name='input') x = tf.keras.layers.Input(shape=(4,), dtype=tf.float32, name='input')
sequence_output, pooled_output = SimpleEmbeddingModel()(x) sequence_output, pooled_output = SimpleEmbeddingModel()(x)
if use_seq_output: if use_sequence_output:
embedding = sequence_output embedding = sequence_output
else: else:
embedding = pooled_output embedding = pooled_output
@ -568,7 +599,7 @@ def keras_embedding_model_fn(opt_cls,
embedding) embedding)
model = tf.keras.Model(inputs=x, outputs=probs, name='model') model = tf.keras.Model(inputs=x, outputs=probs, name='model')
optimizer = opt_cls( optimizer = optimizer_class(
l2_norm_clip=l2_norm_clip, l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier, noise_multiplier=noise_multiplier,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
@ -608,7 +639,7 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
@parameterized.named_parameters( @parameterized.named_parameters(
('DPSGDVectorized_SeqOutput_UnconnectedGradients', ('DPSGDVectorized_SeqOutput_UnconnectedGradients',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),) dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
def testSeqOutputUnconnectedGradientsAsNoneFails(self, cls): def testSeqOutputUnconnectedGradientsAsNoneFails(self, optimizer_class):
"""Tests that DP vectorized optimizers with 'None' unconnected gradients fail. """Tests that DP vectorized optimizers with 'None' unconnected gradients fail.
Sequence models that have unconnected gradients (with Sequence models that have unconnected gradients (with
@ -620,16 +651,16 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
These tests test the various combinations of this flag and the model. These tests test the various combinations of this flag and the model.
Args: Args:
cls: The DP optimizer class to test. optimizer_class: The DP optimizer class to test.
""" """
embedding_model = keras_embedding_model_fn( embedding_model = keras_embedding_model_fn(
cls, optimizer_class,
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.5, noise_multiplier=0.5,
num_microbatches=1, num_microbatches=1,
learning_rate=1.0, learning_rate=1.0,
use_seq_output=True, use_sequence_output=True,
unconnected_gradients_to_zero=False) unconnected_gradients_to_zero=False)
train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32) train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
@ -651,16 +682,17 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
@parameterized.named_parameters( @parameterized.named_parameters(
('DPSGDVectorized_PooledOutput_UnconnectedGradients', ('DPSGDVectorized_PooledOutput_UnconnectedGradients',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),) dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
def testPooledOutputUnconnectedGradientsAsNonePasses(self, cls): def testPooledOutputUnconnectedGradientsAsNonePasses(self, optimizer_class):
"""Tests that DP vectorized optimizers with 'None' unconnected gradients fail.""" """Tests that DP vectorized optimizers with 'None' unconnected gradients fail.
"""
embedding_model = keras_embedding_model_fn( embedding_model = keras_embedding_model_fn(
cls, optimizer_class,
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.5, noise_multiplier=0.5,
num_microbatches=1, num_microbatches=1,
learning_rate=1.0, learning_rate=1.0,
use_seq_output=False, use_sequence_output=False,
unconnected_gradients_to_zero=False) unconnected_gradients_to_zero=False)
train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32) train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
@ -684,16 +716,18 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
('DPSGDVectorized_PooledOutput_UnconnectedGradientsAreZero', ('DPSGDVectorized_PooledOutput_UnconnectedGradientsAreZero',
dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False), dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False),
) )
def testUnconnectedGradientsAsZeroPasses(self, cls, use_seq_output): def testUnconnectedGradientsAsZeroPasses(self, optimizer_class,
"""Tests that DP vectorized optimizers with 'Zero' unconnected gradients pass.""" use_sequence_output):
"""Tests that DP vectorized optimizers with 'Zero' unconnected gradients pass.
"""
embedding_model = keras_embedding_model_fn( embedding_model = keras_embedding_model_fn(
cls, optimizer_class,
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.5, noise_multiplier=0.5,
num_microbatches=1, num_microbatches=1,
learning_rate=1.0, learning_rate=1.0,
use_seq_output=use_seq_output, use_sequence_output=use_sequence_output,
unconnected_gradients_to_zero=True) unconnected_gradients_to_zero=True)
train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32) train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
@ -710,5 +744,6 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
# other exceptions are errors. # other exceptions are errors.
self.fail('ValueError raised by model.fit().') self.fail('ValueError raised by model.fit().')
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()