Efficient DPSGD with support for microbatched losses.
PiperOrigin-RevId: 513886957
Parent: cbf34f2b04
Commit: 8bfafdd74d
6 changed files with 252 additions and 104 deletions
@@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""

from typing import Dict, Iterable, Text, Union
from typing import Dict, Iterable, Optional, Text, Union

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

@@ -31,7 +31,9 @@ InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]


def get_registry_generator_fn(
tape: tf.GradientTape, layer_registry: lr.LayerRegistry
tape: tf.GradientTape,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
):
"""Creates the generator function for `compute_gradient_norms()`."""
if layer_registry is None:

@@ -50,14 +52,14 @@ def get_registry_generator_fn(
)
registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, tape
layer_instance, args, tape, num_microbatches
)
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
else:
# Non-trainable layer.
return layer_instance(*args, **kwargs), None

return registry_generator_fn
return registry_generator_fn


def compute_gradient_norms(

@@ -65,6 +67,7 @@ def compute_gradient_norms(
x_batch: InputTensor,
y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
):
"""Computes the per-example loss gradient norms for given data.


@@ -83,13 +86,21 @@ def compute_gradient_norms(
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
of num_microbatches). When microbatches are used, the loss is always
assumed to be the mean over each microbatch, and the gradient norm is
computed per microbatch.

Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
per-example loss function.
"""
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn(tape, layer_registry)
registry_generator_fn = get_registry_generator_fn(
tape, layer_registry, num_microbatches
)
# First loop computes the model outputs, summed loss, and generator outputs.
with tape:
model_outputs, generator_outputs_list = (

@@ -102,6 +113,10 @@ def compute_gradient_norms(
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs)
if num_microbatches is not None:
losses = tf.reduce_mean(
lr.add_microbatch_axis(losses, num_microbatches), axis=1
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops.
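As an aside, the microbatch grouping performed on `losses` above amounts to a reshape followed by a mean over the second axis. A standalone sketch (values and shapes are illustrative, not part of the diff):

```python
import tensorflow as tf

# Illustrative sketch of the grouping above: 8 per-example losses folded into
# 4 microbatches of size 2, then averaged within each microbatch before the
# summed loss is formed.
per_example_losses = tf.constant([1., 2., 3., 4., 5., 6., 7., 8.])
num_microbatches = 4
grouped = tf.reshape(per_example_losses, [num_microbatches, -1])  # shape [4, 2]
microbatch_losses = tf.reduce_mean(grouped, axis=1)               # [1.5, 3.5, 5.5, 7.5]
summed_loss = tf.reduce_sum(microbatch_losses)
```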
@@ -149,6 +164,7 @@ def compute_pred_and_clipped_gradients(
y_batch: tf.Tensor,
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
):
"""Computes the per-example predictions and per-example clipped loss gradient.


@@ -177,6 +193,10 @@ def compute_pred_and_clipped_gradients(
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
of num_microbatches).

Returns:
A `tuple` `(y_pred, grad)`. The first element is the prediction generated by

@@ -184,11 +204,21 @@ def compute_pred_and_clipped_gradients(
gradient of the loss function.
"""
gradient_norms = compute_gradient_norms(
input_model, x_batch, y_batch, layer_registry
input_model, x_batch, y_batch, layer_registry, num_microbatches
)
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
with tf.GradientTape() as tape:
y_pred = input_model(x_batch, training=True)
if num_microbatches is not None:
y_batch = lr.add_microbatch_axis(y_batch, num_microbatches)
y_pred = lr.add_microbatch_axis(y_pred, num_microbatches)
# Warning: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches,
# as this is the assumption made when computing the gradient norm.
# It is indeed the case for multiple Keras loss functions
# (e.g. mean_squared_error and binary_crossentropy). However, it
# is not defined in the contract so may not hold, especially for
# custom losses.
loss_value = input_model.compute_loss(
x_batch, y_batch, y_pred, loss_weights
)
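The warning in the code above is the key caveat of this change: the gradient-norm computation assumes the reduced loss is the mean over each microbatch. A quick self-contained check of that assumption for a built-in Keras loss (a sketch with made-up shapes, not part of the diff):

```python
import tensorflow as tf

# Sketch: for MeanSquaredError with the default reduction, the reduced loss
# equals the mean of the per-microbatch means, which is the behavior the
# clipped-gradient computation above relies on.
y_true = tf.random.uniform([8, 3])
y_pred = tf.random.uniform([8, 3])
num_microbatches = 4

reduced = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
per_example = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE
)(y_true, y_pred)
per_microbatch = tf.reduce_mean(
    tf.reshape(per_example, [num_microbatches, -1]), axis=1
)
tf.debugging.assert_near(reduced, tf.reduce_mean(per_microbatch), rtol=1e-5)
```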
@@ -13,7 +13,7 @@
# limitations under the License.

import itertools
from typing import Callable, Any, List, Union
from typing import Any, Callable, List, Optional, Union

from absl.testing import parameterized
import tensorflow as tf

@@ -49,14 +49,17 @@ class DoubleDense(tf.keras.layers.Layer):


def double_dense_layer_computation(
layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape
layer_instance: tf.keras.layers.Layer,
inputs: Any,
tape: tf.GradientTape,
num_microbatches: Optional[int],
):
"""Layer registry function for the custom `DoubleDense` layer class."""
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
layer_instance.dense1, inputs, tape
layer_instance.dense1, inputs, tape, num_microbatches
)
vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation(
layer_instance.dense2, (outputs,), tape
layer_instance.dense2, (outputs,), tape, num_microbatches
)

def sqr_norm_fn(base_vars):

@@ -68,7 +71,10 @@ def double_dense_layer_computation(


def compute_true_gradient_norms(
input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor
input_model: tf.keras.Model,
x_batch: tf.Tensor,
y_batch: tf.Tensor,
num_microbatches: Optional[int],
):
"""Computes the real gradient norms for an input `(model, x, y)`."""
loss_config = input_model.loss.get_config()

@@ -77,13 +83,22 @@ def compute_true_gradient_norms(
with tf.GradientTape(persistent=True) as tape:
y_pred = input_model(x_batch)
loss = per_example_loss_fn(y_batch, y_pred)
if num_microbatches is not None:
loss = tf.reduce_mean(
tf.reshape(
loss,
tf.concat([[num_microbatches, -1], tf.shape(loss)[1:]], axis=0),
),
axis=1,
)
if isinstance(loss, tf.RaggedTensor):
loss = loss.to_tensor()
sqr_norms = []
for var in input_model.trainable_variables:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, len(jacobian.shape))
sqr_norms.append(tf.reduce_sum(tf.square(jacobian), axis=reduction_axes))
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
sqr_norms.append(sqr_norm)
sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))


@@ -93,6 +108,7 @@ def get_computed_and_true_norms(
layer_generator: LayerGenerator,
input_dims: Union[int, List[int]],
output_dim: int,
num_microbatches: Optional[int],
is_eager: bool,
x_input: tf.Tensor,
rng_seed: int = 777,

@@ -113,6 +129,7 @@ def get_computed_and_true_norms(
`idim` and returns output tensors of dimension `odim`.
input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
output_dim: The output dimension of the test `tf.keras.Model` instance.
num_microbatches: The number of microbatches. None or an integer.
is_eager: A `bool` that is `True` if the model should be run eagerly.
x_input: `tf.Tensor` inputs to be tested.
rng_seed: An `int` used to initialize model weights.

@@ -137,10 +154,16 @@ def get_computed_and_true_norms(
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
model, x_input, y_batch, layer_registry=registry
model,
x_input,
y_batch,
layer_registry=registry,
num_microbatches=num_microbatches,
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(model, x_input, y_batch)
true_norms = compute_true_gradient_norms(
model, x_input, y_batch, num_microbatches
)
return (computed_norms, true_norms)



@@ -322,18 +345,30 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
model_name=list(get_dense_model_generators().keys()),
layer_name=list(get_dense_layer_generators().keys()),
input_dim=[1, 2],
input_dim=[4],
output_dim=[1, 2],
num_microbatches=[None, 1, 2],
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self, model_name, layer_name, input_dim, output_dim, is_eager
self,
model_name,
layer_name,
input_dim,
output_dim,
num_microbatches,
is_eager,
):
model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name]
x_batches = get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry()
for x_batch in x_batches:
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
continue
if model_name == 'tower1':
x_input = [x_batch, x_batch]
else:

@@ -343,6 +378,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_generator,
input_dim,
output_dim,
num_microbatches,
is_eager,
x_input,
registry=default_registry,

@@ -362,6 +398,10 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
dtype=tf.int32,
),
# 3D inputs.
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
tf.convert_to_tensor(

@@ -371,14 +411,24 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
dtype=tf.int32,
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
dtype=tf.int32,
),
],
model_name=list(get_embedding_model_generators().keys()),
output_dim=[1, 2],
is_eager=[True, False],
output_dim=[2],
num_microbatches=[None, 1, 2],
is_eager=[True],
)
def test_gradient_norms_on_various_models(
self, x_batch, model_name, output_dim, is_eager
self, x_batch, model_name, output_dim, num_microbatches, is_eager
):
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
return
valid_test_input = (
not isinstance(x_batch, tf.RaggedTensor)
and model_name == 'weighted_bow1'

@@ -391,6 +441,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_input=x_batch,
registry=default_registry,

@@ -403,20 +454,27 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
input_dim=[1, 2],
output_dim=[1, 2],
num_microbatches=[None, 1, 2],
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self, input_dim, output_dim, is_eager
self, input_dim, output_dim, num_microbatches, is_eager
):
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)
x_batches = get_nd_test_batches(input_dim)
for x_batch in x_batches:
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
continue
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=make_two_layer_sequential_model,
layer_generator=lambda a, b: DoubleDense(b),
input_dims=input_dim,
output_dim=output_dim,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_input=x_batch,
registry=registry,
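The `double_dense_layer_computation` helper above also illustrates the updated registry contract: custom registry functions now take `num_microbatches` as a fourth argument and forward it to the built-in computations they wrap. A hedged sketch of a minimal custom registration (the wrapper name is hypothetical and it simply delegates, for illustration only):

```python
from typing import Any, Optional

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

def my_dense_layer_computation(
    layer_instance: tf.keras.layers.Dense,
    inputs: Any,
    tape: tf.GradientTape,
    num_microbatches: Optional[int] = None,
):
  # Delegate to the built-in dense computation, passing num_microbatches on.
  return layer_registry.dense_layer_computation(
      layer_instance, inputs, tape, num_microbatches
  )

registry = layer_registry.make_default_layer_registry()
registry.insert(tf.keras.layers.Dense, my_dense_layer_computation)
```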
@@ -157,8 +157,8 @@ def all_trainable_layers_are_registered(

def add_aggregate_noise(
input_model: tf.keras.Model,
x_batch: InputTensor,
clipped_grads: List[tf.Tensor],
clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
) -> List[tf.Tensor]:

@@ -169,8 +169,9 @@ def add_aggregate_noise(

Args:
input_model: The `tf.keras.Model` to obtain the layers from.
x_batch: An `InputTensor` to be fed into the input layer of the model.
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size, used for normalizing the noise, when the loss
reduction is AUTO or SUM_OVER_BATCH_SIZE.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.


@@ -186,17 +187,7 @@ def add_aggregate_noise(
]:
if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
if isinstance(x_batch, tf.Tensor):
scale /= tf.cast(tf.shape(x_batch)[0], tf.float32)
elif isinstance(x_batch, dict):
batch_sizes = [
tf.cast(tf.shape(v)[0], tf.float32) for v in x_batch.values()
]
scale /= tf.math.reduce_min(batch_sizes)
else:
raise NotImplementedError(
'Unknown container/class %s for input' % x_batch.__class__.__name__
)
scale /= tf.cast(batch_size, tf.float32)

def add_noise(g):
return g + tf.random.normal(
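With this refactor, `add_aggregate_noise` no longer infers the denominator from `x_batch`; the caller passes `batch_size` (the number of microbatches when microbatching is on, otherwise the number of examples). A sketch of the resulting noise scale, assuming `scale` starts at `l2_norm_clip * noise_multiplier` as the docstring's definition of `noise_multiplier` implies:

```python
import tensorflow as tf

# Hedged sketch (not the library function): with a mean-style loss reduction,
# each aggregated gradient receives Gaussian noise with standard deviation
# l2_norm_clip * noise_multiplier / batch_size.
l2_norm_clip = 1.0
noise_multiplier = 0.5
batch_size = 32  # or num_microbatches, when microbatching is enabled

scale = l2_norm_clip * noise_multiplier / batch_size
grad = tf.zeros([10, 10])
noised_grad = grad + tf.random.normal(tf.shape(grad), stddev=scale)
```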
@@ -38,9 +38,18 @@ whose i-th entry is the L2 norm of the i-th input vector, then

where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
"""

from typing import Any, Callable, Dict, Iterable, Text, Tuple, Type, Union
We also extend fast gradient norm computation to the case when the losses
are microbatched, i.e. each per-example loss is the mean of a set of losses.
This could be useful for achieving user-level privacy and for improving the
quality of DP models through better estimation of the gradients due to
aggregation at the microbatch level.
"""
# copybara.strip_begin
# The detailed algorithm can be found in go/fast-dpsgd-mb.
# copybara.strip_end

from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Type, Union
import tensorflow as tf

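To make the microbatched objective concrete: each microbatch contributes the gradient of the mean of its per-example losses, and it is the norm of that per-microbatch gradient that gets clipped. A small self-contained sketch (the linear model and MSE loss are assumptions chosen for illustration):

```python
import tensorflow as tf

w = tf.Variable(tf.random.normal([3, 1]))
x = tf.random.normal([4, 3])   # batch of 4 examples
y = tf.random.normal([4, 1])
num_microbatches = 2

with tf.GradientTape() as tape:
  per_example = tf.reduce_mean(tf.square(tf.matmul(x, w) - y), axis=1)  # [4]
  per_microbatch = tf.reduce_mean(
      tf.reshape(per_example, [num_microbatches, -1]), axis=1
  )  # [2]
  first_microbatch_loss = per_microbatch[0]

# The gradient whose norm is clipped for microbatch 0: the average of the
# per-example gradients of the first two examples.
grad_0 = tape.gradient(first_microbatch_loss, w)
```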
@@ -56,6 +65,7 @@ RegistryFunction = Callable[
]

InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
BatchSize = Union[int, tf.Tensor]


# ==============================================================================

@@ -88,6 +98,37 @@ class LayerRegistry:
self._registry[layer_key] = layer_registry_function


# ==============================================================================
# Utilities
# ==============================================================================
def add_microbatch_axis(
x: tf.Tensor,
num_microbatches: Optional[BatchSize],
) -> tf.Tensor:
"""Adds the microbatch axis.

Reshape the input tensor to replace the first (batch) dimension with the
shape [num_microbatches, batch_size / num_microbatches]. The batch size
must be a multiple of num_microbatches (unless it is None, meaning
num_microbatches is the same as the batch size).

Args:
x: the input tensor.
num_microbatches: None or a numeric value or a scalar `tf.Tensor`.

Returns:
The reshaped input tensor.
"""
if num_microbatches is None:
return tf.expand_dims(x, 1)
with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
):
return tf.reshape(
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
)

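The reshape above is easiest to see on a concrete tensor. The sketch below reimplements the two branches on illustrative shapes rather than calling the library function:

```python
import tensorflow as tf

x = tf.reshape(tf.range(12, dtype=tf.float32), [6, 2])  # batch of 6 examples

# num_microbatches = 3: [6, 2] -> [3, 2, 2] (3 microbatches of 2 examples).
with_microbatches = tf.reshape(
    x, tf.concat([[3, -1], tf.shape(x)[1:]], axis=0)
)

# num_microbatches = None: [6, 2] -> [6, 1, 2] (each example is its own
# microbatch of size 1).
without_microbatches = tf.expand_dims(x, 1)
```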

# ==============================================================================
# Supported Keras layers
# ==============================================================================

@@ -95,6 +136,7 @@ def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
inputs: Tuple[InputTensor],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`.


@@ -111,6 +153,9 @@ def dense_layer_computation(
output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
indicating whether and how the losses are grouped into microbatches. If
not None, num_microbatches must divide the batch size.

Returns:
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the

@@ -132,21 +177,29 @@ def dense_layer_computation(
tape.watch(base_vars)
layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars

def sqr_norm_fn(base_vars_grads):
sqr_inputs = tf.square(*inputs)
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes)

def _compute_gramian(x):
if num_microbatches is not None:
x_microbatched = add_microbatch_axis(x, num_microbatches)
return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
else:
# Special handling for better efficiency
return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))

inputs_gram = _compute_gramian(*inputs)
base_vars_grads_gram = _compute_gramian(base_vars_grads)
if layer_instance.use_bias:
# Adding a bias term is equivalent to a layer with no bias term and which
# adds an additional variable to the layer input that only takes a
# constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
# of the squared values of the inputs.
input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype)
reduction_axes = tf.range(1, tf.rank(base_vars_grads))
base_vars_sqr_norms = tf.reduce_sum(
tf.square(base_vars_grads), axis=reduction_axes
inputs_gram += 1.0
return tf.reduce_sum(
inputs_gram * base_vars_grads_gram,
axis=tf.range(1, tf.rank(inputs_gram)),
)
return input_sqr_norms * base_vars_sqr_norms

return base_vars, outputs, sqr_norm_fn
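The `_compute_gramian` path above rests on the identity that the squared Frobenius norm of a sum of outer products equals the sum of the elementwise product of the two Gram matrices; summing `inputs_gram * base_vars_grads_gram` within a microbatch therefore yields that microbatch's squared weight-gradient norm. A quick numeric check of the identity (shapes are illustrative; this is not the registry code itself):

```python
import tensorflow as tf

microbatch_size, input_dim, output_dim = 3, 4, 2
x = tf.random.normal([microbatch_size, input_dim])   # dense-layer inputs
g = tf.random.normal([microbatch_size, output_dim])  # pre-activation gradients

# ||sum_i x_i g_i^T||_F^2 computed directly ...
direct = tf.reduce_sum(tf.square(tf.einsum('bi,bo->io', x, g)))

# ... equals sum_{i,j} (x_i . x_j) * (g_i . g_j), the Gram-matrix form.
via_gram = tf.reduce_sum(
    tf.matmul(x, x, transpose_b=True) * tf.matmul(g, g, transpose_b=True)
)
tf.debugging.assert_near(direct, via_gram, rtol=1e-4)
```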
@@ -155,6 +208,7 @@ def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
inputs: Tuple[InputTensor],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`.


@@ -171,6 +225,9 @@ def embedding_layer_computation(
output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
indicating whether and how the losses are grouped into microbatches. If
not None, num_microbatches must divide the batch size.

Returns:
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the

@@ -219,6 +276,13 @@ def embedding_layer_computation(
raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
)
row_indices = tf.cast(row_indices, tf.int32)
if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int32)
nrows = num_microbatches
row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int32
)
# Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()`
# call. The sum is reduced by the repeated embedding indices and batch
# index. It is adapted from the logic in:
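The embedding branch above only needs to remap example (row) indices to microbatch indices by integer division before the per-row reduction; a tiny sketch of that mapping (values are illustrative):

```python
import tensorflow as tf

row_indices = tf.constant([0, 1, 2, 3, 4, 5])  # one entry per gathered row
nrows, num_microbatches = 6, 3
microbatch_size = nrows // num_microbatches

# Examples {0, 1} -> microbatch 0, {2, 3} -> 1, {4, 5} -> 2.
print(tf.math.floordiv(row_indices, microbatch_size).numpy())  # [0 0 1 1 2 2]
```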
@@ -26,53 +26,64 @@ def make_dp_model_class(cls):
__doc__ = (
"""DP subclass of `{base_model}`.

This can be used as a differentially private replacement for
{base_model}. This class implements DP-SGD using the standard
Gaussian mechanism.
This can be used as a differentially private replacement for
{base_model}. This class implements DP-SGD using the standard
Gaussian mechanism.

This class also utilizes a faster gradient clipping algorithm if the
following two conditions hold:
This class also utilizes a faster gradient clipping algorithm if the
following two conditions hold:
(i) the trainable layers of the model are keys in the `dict` input
`layer_registry`,
(ii) the loss `tf.Tensor` for a given batch of examples is either a
scalar or a 2D `tf.Tensor` that has only one column
`(i.e., tf.shape(loss)[1] == 1)` and whose i-th row corresponds to
the loss of the i-th example.
This clipping algorithm specifically computes clipped gradients at the
per-example level using the layer registry functions in `layer_registry`
(see clip_grads.py for more information about the algorithm). In this
setting, microbatching is not used (it is equivalent to
`num_microbatches == batch_size`), and the input `num_microbatches`
is ignored.
This clipping algorithm specifically computes clipped gradients at the
per-example or per-microbatch (when `num_microbatches` is not None)
level using the layer registry functions in `layer_registry` (see
clip_grads.py for more information about the algorithm).

When instantiating this class, you need to supply several
DP-related arguments followed by the standard arguments for
`{short_base_model}`.
WARNING: with faster gradient clipping, and when num_microbatches is not
None, the per-microbatch loss is assumed to be computed as the mean
of the loss over the microbatch, or effectively, by reshaping the loss
from the shape [batch_size, ...] to the shape
[num_microbatches, batch_size/num_microbatches, ...] and computing the
mean of the loss over the microbatches. This would require that the loss
function behaves accordingly. This is true for multiple common
predefined Keras loss functions (e.g. mean_squared_error,
binary_crossentropy) but may not hold for custom losses (and how such
aggregation is done is not exposed by the loss function, unfortunately).
It is the caller's responsibility to make sure that the loss function
does behave this way.

Examples:
When instantiating this class, you need to supply several
DP-related arguments followed by the standard arguments for
`{short_base_model}`.

```python
# Create Model instance.
model = {dp_model_class}(l2_norm_clip=1.0, noise_multiplier=0.5, use_xla=True,
<standard arguments>)
```
Examples:

You should use your {dp_model_class} instance with a standard instance
of `tf.keras.Optimizer` as the optimizer, and a standard reduced loss.
You do not need to use a differentially private optimizer.
```python
# Create Model instance.
model = {dp_model_class}(l2_norm_clip=1.0, noise_multiplier=0.5, use_xla=True,
<standard arguments>)
```

```python
# Use a standard (non-DP) optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
You should use your {dp_model_class} instance with a standard instance
of `tf.keras.Optimizer` as the optimizer, and a standard reduced loss.
You do not need to use a differentially private optimizer.

# Use a standard reduced loss.
loss = tf.keras.losses.MeanSquaredError()
```python
# Use a standard (non-DP) optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=32)
```
# Use a standard reduced loss.
loss = tf.keras.losses.MeanSquaredError()

"""
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=32)
```

"""
).format(
base_model='tf.keras.' + cls.__name__,
short_base_model=cls.__name__,

@@ -115,6 +126,7 @@ def make_dp_model_class(cls):
if isinstance(num_microbatches, bool):
raise ValueError('Boolean value supplied for `num_microbatches`. '
'Did you intend it for `use_xla`?')
self._num_microbatches = num_microbatches

# If all the trainable layers are in the input layer registry, we
# don't need to use microbatching and can instead use the "fast"

@@ -126,16 +138,8 @@ def make_dp_model_class(cls):
)
and gradient_clipping_utils.has_internal_compute_graph(self)
):
if num_microbatches is not None:
raise ValueError(
'Cannot initialize a model where num_microbatches '
'is not `None` and all trainable layers are '
'registered in layer_registry.'
)
self._num_microbatches = None
self._enable_fast_peg_computation = True
else:
self._num_microbatches = num_microbatches
self._enable_fast_peg_computation = False

if use_xla:

@@ -198,10 +202,20 @@ def make_dp_model_class(cls):
# trick, and uses these norms to clip the per-example gradients.
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
self, x, y, self._l2_norm_clip, self._layer_registry
self,
x,
y,
self._l2_norm_clip,
self._layer_registry,
self._num_microbatches,
)
batch_size = self._num_microbatches or tf.shape(y)[0]
grads = gradient_clipping_utils.add_aggregate_noise(
self, x, clipped_grads, self._l2_norm_clip, self._noise_multiplier
self,
clipped_grads,
batch_size,
self._l2_norm_clip,
self._noise_multiplier,
)
else:
logging.info('Computing gradients using microbatching.')
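Putting the pieces together, the fast clipping path can now be combined with microbatching. The following is a hedged usage sketch: the import path and constructor arguments follow the test file and docstring in this commit, but the exact public API should be checked against the released package.

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.keras_models import dp_keras_model

model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    num_microbatches=2,          # gradients are clipped per microbatch
    layer_registry=layer_registry.make_default_layer_registry(),
    layers=[
        tf.keras.layers.InputLayer(input_shape=(2,)),
        tf.keras.layers.Dense(1),
    ],
)
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    loss=tf.keras.losses.MeanSquaredError(),  # standard reduced loss
)
train_data = tf.random.normal([8, 2])
train_labels = tf.random.normal([8, 1])
model.fit(train_data, train_labels, epochs=1, batch_size=4)  # 4 % 2 == 0
```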
@@ -139,9 +139,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
learning_rate = 1.0

for test_reg, test_nm in zip(
get_layer_registries(), [num_microbatches, None]
):
for test_reg in get_layer_registries():
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()


@@ -149,7 +147,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=test_nm,
num_microbatches=num_microbatches,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),

@@ -173,10 +171,11 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
)
expected_weights = np.squeeze(-learning_rate * expected_grads)

self.assertAllClose(model_weights, expected_weights)

@parameterized.named_parameters(
('noise_multiplier 3 2 None', 3.0, 2.0, None),
('noise_multiplier 5 4 None', 5.0, 4.0, None),
('noise_multiplier 3 2 1', 3.0, 2.0, 1),
('noise_multiplier 5 4 1', 5.0, 4.0, 1),
('noise_multiplier 3 2 2', 3.0, 2.0, 2),

@@ -198,9 +197,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):

learning_rate = 1.0

for test_reg, test_nm in zip(
get_layer_registries(), [num_microbatches, None]
):
for test_reg in get_layer_registries():
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()


@@ -208,7 +205,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
num_microbatches=test_nm,
num_microbatches=num_microbatches,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(1000,)),

@@ -220,11 +217,7 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4)

effective_num_microbatches = (
train_data.shape[0]
if model._num_microbatches is None
else num_microbatches
)
effective_num_microbatches = num_microbatches or train_data.shape[0]

model_weights = model.get_weights()
measured_std = np.std(model_weights[0])

@@ -248,16 +241,14 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
train_labels = np.array([[0], [1], [1], [0]])
learning_rate = 1.0

for test_reg, test_nm in zip(
get_layer_registries(), [num_microbatches, None]
):
for test_reg in get_layer_registries():
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = dp_keras_model.DPSequential(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=test_nm,
num_microbatches=num_microbatches,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),