Support weighted losses in DPModel.

PiperOrigin-RevId: 538011437
Author: Walid Krichene, 2023-06-05 16:26:38 -07:00 (committed by A. Unique TensorFlower)
Parent: 60d237be83
Commit: 18c43b351b
5 changed files with 297 additions and 217 deletions

Changed file: tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@ -69,9 +69,10 @@ def get_registry_generator_fn(
def compute_gradient_norms( def compute_gradient_norms(
input_model: tf.keras.Model, input_model: tf.keras.Model,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor, x_batch: InputTensor,
y_batch: tf.Tensor, y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry, weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[LossFn] = None, per_example_loss_fn: Optional[LossFn] = None,
num_microbatches: Optional[lr.BatchSize] = None, num_microbatches: Optional[lr.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None, trainable_vars: Optional[List[tf.Variable]] = None,
@ -83,16 +84,19 @@ def compute_gradient_norms(
Args: Args:
input_model: The `tf.keras.Model` from which to obtain the layers from. The input_model: The `tf.keras.Model` from which to obtain the layers from. The
loss of the model *must* be a scalar loss. loss of the model *must* be a scalar loss. When using microbatching, the
loss reduction must be mean.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
x_batch: An `InputTensor` representing a batch of inputs to the model. The x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension. first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the must be the batch dimension. The number of examples should match the
number of examples in `x_batch`. number of examples in `x_batch`.
layer_registry: A `LayerRegistry` instance containing functions that help weight_batch: Optional batch of weights, passed to the loss function.
compute gradient norms quickly. See Weights apply to the loss prior to clipping.
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
per_example_loss_fn: takes as input predictions, labels and weights, and per_example_loss_fn: takes as input predictions, labels and weights, and
outputs a vector of per-example losses. If None, derived from outputs a vector of per-example losses. If None, derived from
`input_model.loss` by disabling its reduction. `input_model.loss` by disabling its reduction.
@ -108,8 +112,11 @@ def compute_gradient_norms(
variables are included. variables are included.
Returns: Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th A scalar vector, whose i-th entry is the norm of the gradient of the i-th
per-example loss function. weighted example loss (when num_microbatches is None) or the norm of the
gradient of the i-th microbatch loss (defined as the mean over the microbatch).
Note that when the loss is weighted (`weight_batch` is not None), weights
are applied prior to clipping.
""" """
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn( registry_generator_fn = get_registry_generator_fn(
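
A minimal usage sketch of the updated signature described above; the toy Sequential model, data, and weights are illustrative and not part of this change:

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

    # Toy model and data; only the call pattern matters here.
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(2,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(loss=tf.keras.losses.MeanSquaredError())

    x_batch = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    y_batch = tf.constant([[1.0], [0.0]])
    weight_batch = tf.constant([0.25, 1.0])  # scales each example's loss before clipping

    norms = clip_grads.compute_gradient_norms(
        model,
        layer_registry.make_default_layer_registry(),
        x_batch,
        y_batch,
        weight_batch=weight_batch,
    )
    print(norms)  # one norm per weighted example loss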
@ -127,7 +134,7 @@ def compute_gradient_norms(
loss_config = input_model.loss.get_config() loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config) per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs) losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None: if losses.shape is None:
raise NotImplementedError( raise NotImplementedError(
"The unreduced (or per-example) loss's shape cannot be `None`" "The unreduced (or per-example) loss's shape cannot be `None`"
@ -140,7 +147,7 @@ def compute_gradient_norms(
) )
if num_microbatches is not None: if num_microbatches is not None:
losses = tf.reduce_mean( losses = tf.reduce_mean(
lr.add_microbatch_axis(losses, num_microbatches), axis=1 lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
) )
summed_loss = tf.reduce_sum(losses) summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating # Unwrap the generator outputs so that the next loop avoids duplicating
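
The `from_config` trick above is standard Keras. As a standalone sketch (the loss choice and values are illustrative), re-creating a loss with its reduction disabled yields one weighted loss per example, so sample weights take effect before any clipping:

    import tensorflow as tf

    loss_obj = tf.keras.losses.MeanSquaredError()
    config = loss_obj.get_config()
    config['reduction'] = tf.keras.losses.Reduction.NONE
    per_example_loss_fn = loss_obj.from_config(config)

    y_true = tf.constant([[1.0], [2.0], [3.0]])
    y_pred = tf.constant([[1.5], [2.0], [0.0]])
    weights = tf.constant([0.5, 1.0, 2.0])

    print(per_example_loss_fn(y_true, y_pred, weights))  # -> [0.125, 0.0, 18.0]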
@ -165,6 +172,10 @@ def compute_gradient_norms(
vars_list, vars_list,
unconnected_gradients=tf.UnconnectedGradients.ZERO, unconnected_gradients=tf.UnconnectedGradients.ZERO,
) )
if not grads_list:
raise ValueError('The gradient list cannot be empty.')
if len(grads_list) != len(sqr_norm_fns_list):
raise ValueError('There must be as many norms as gradients.')
sqr_norm_list = [] sqr_norm_list = []
for grads, f in zip(grads_list, sqr_norm_fns_list): for grads, f in zip(grads_list, sqr_norm_fns_list):
sqr_norm_list.append(f(grads)) sqr_norm_list.append(f(grads))
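
The per-layer squared norms collected above are then combined into one norm per example. The final stack/sqrt step is not shown in this hunk, so the sketch below (with hypothetical numbers) is an illustration of that aggregation rather than the library code:

    import tensorflow as tf

    # Hypothetical per-layer squared gradient norms for a batch of 4 examples.
    sqr_norm_list = [
        tf.constant([1.0, 4.0, 0.0, 9.0]),
        tf.constant([3.0, 0.0, 1.0, 7.0]),
    ]
    sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)             # shape [batch, num_layers]
    grad_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))  # shape [batch]
    print(grad_norms)  # -> [2.0, 2.0, 1.0, 4.0]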
@ -199,30 +210,26 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
def compute_clipped_gradients_and_outputs( def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model, input_model: tf.keras.Model,
x_batch: InputTensor,
y_batch: tf.Tensor,
l2_norm_clip: float, l2_norm_clip: float,
layer_registry: lr.LayerRegistry, layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[lr.BatchSize] = None, num_microbatches: Optional[lr.BatchSize] = None,
clipping_loss: Optional[LossFn] = None, clipping_loss: Optional[LossFn] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]: ) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
"""Computes the per-example clipped loss gradient and other useful outputs. """Computes the per-example clipped loss gradient and other useful outputs.
Given a batch of observations `(x_batch, y_batch)`, the main steps of this Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
function are: (i) compute the l2-norm of the gradients of the trainable steps of this function are:
variables of `input_model` for each example in the batch; (ii) use the norms (i) compute the l2-norm of the gradients w.r.t. the trainable variables of
computed in (i) to obtain "clip_weights" that are used to generate a weighted `input_model`, for each weighted example loss in the batch;
loss function whose gradient for each example has l2-norm at most (ii) use the norms computed in (i) to obtain "clip_weights" that are used to
`l2_norm_clip`; (iii) output the clipped gradients in (ii) and other useful reweight the loss function, such that each gradient of this reweighted loss
outputs to the caller. has l2-norm at most `l2_norm_clip`.
Args: Args:
input_model: The `tf.keras.Model` from which to obtain the layers from. input_model: The `tf.keras.Model` from which to obtain the layers from.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
l2_norm_clip: A `float` indicating the norm to which per-example gradients l2_norm_clip: A `float` indicating the norm to which per-example gradients
will be clipped. That is, all gradients of the per-example loss functions will be clipped. That is, all gradients of the per-example loss functions
will have norm at most `l2_norm_clip`. will have norm at most `l2_norm_clip`.
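
A sketch of the clip-weight idea from steps (i) and (ii) above, with illustrative values (`compute_clip_weights` itself is not shown in this hunk): each loss term is rescaled by a factor equivalent to min(1, l2_norm_clip / grad_norm), so the rescaled gradient has norm at most l2_norm_clip.

    import tensorflow as tf

    l2_norm_clip = 1.0
    gradient_norms = tf.constant([0.5, 2.0, 4.0])
    clip_weights = l2_norm_clip / tf.maximum(gradient_norms, l2_norm_clip)
    print(clip_weights)                   # -> [1.0, 0.5, 0.25]
    print(clip_weights * gradient_norms)  # -> [0.5, 1.0, 1.0], all <= l2_norm_clip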
@ -232,6 +239,15 @@ def compute_clipped_gradients_and_outputs(
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples). trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
weight_batch: Optional vector of weights, passed to the loss function. Must
be of size [batch_size]. In case of microbatching, this will be reshaped
to [num_microbatches, batch_size/num_microbatches] before passing it to
the loss.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple num_microbatches (in this case, the batch dimension needs to be a multiple
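
The reshape described for `weight_batch` is a plain [batch_size] -> [num_microbatches, batch_size / num_microbatches] split. For example, with hypothetical sizes batch_size = 6 and num_microbatches = 3:

    import tensorflow as tf

    weight_batch = tf.constant([1.0, 1.0, 2.0, 2.0, 0.5, 0.5])  # batch_size = 6
    reshaped = tf.reshape(weight_batch, [3, 2])                  # num_microbatches = 3
    print(reshaped)  # -> [[1.0, 1.0], [2.0, 2.0], [0.5, 0.5]]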
@ -243,11 +259,11 @@ def compute_clipped_gradients_and_outputs(
the value of the clipped loss does not reflect the true loss. the value of the clipped loss does not reflect the true loss.
Returns: Returns:
A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the clipped_grad: list of the clipped gradients of the loss function (one per
clipped gradient of the loss function, the second is the result of trainable variable in `input_model`).
applying `input_model` to `x_batch`, and the third is loss value of y_pred: the result of applying `input_model` to `x_batch`.
`input_model`, weighted by the loss weights generated by a specific clipping_loss_value: the loss value weighted in such a way that its gradient
`compute_clip_weights()` call. is `clipped_grad`.
""" """
if input_model.loss.reduction == 'none': if input_model.loss.reduction == 'none':
raise NotImplementedError( raise NotImplementedError(
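
A usage sketch of the reordered signature and the three return values, continuing the illustrative `model`, `x_batch`, `y_batch`, and `weight_batch` from the earlier sketch (keyword arguments are used for clarity):

    clipped_grads, y_pred, clipping_loss_value = (
        clip_grads.compute_clipped_gradients_and_outputs(
            model,
            l2_norm_clip=1.0,
            layer_registry=layer_registry.make_default_layer_registry(),
            x_batch=x_batch,
            y_batch=y_batch,
            weight_batch=weight_batch,
        )
    )
    # `clipped_grads` pairs with model.trainable_variables; the gradient of
    # `clipping_loss_value` with respect to those variables is `clipped_grads`.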
@ -258,13 +274,26 @@ def compute_clipped_gradients_and_outputs(
clipping_loss = input_model.compiled_loss clipping_loss = input_model.compiled_loss
gradient_norms = compute_gradient_norms( gradient_norms = compute_gradient_norms(
input_model, input_model,
layer_registry,
x_batch, x_batch,
y_batch, y_batch,
layer_registry, weight_batch,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
trainable_vars=input_model.trainable_variables, trainable_vars=input_model.trainable_variables,
) )
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
if weight_batch is not None:
# Let w be the `weight_batch`, c be the `clip_weights`, and l be the losses.
# c is computed based on the gradient of w*l, so that if we scale w*l by c,
# the result has bounded per-example gradients. So the loss to optimize is
# c*w*l. Here we compute c*w before passing it to the loss.
weight_batch = lr.maybe_add_microbatch_axis(weight_batch, num_microbatches)
if num_microbatches is None:
clip_weights = clip_weights * weight_batch # both have shape [batch_size]
else:
# In this case, weight_batch has shape [num_microbatches, microbatch_size];
# multiply it by clip_weights (of shape [num_microbatches]) via broadcasting.
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
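
Concretely, with num_microbatches = 2 and a batch of 4 examples, the reshaped weights and the broadcasted product look like this (values are illustrative):

    import tensorflow as tf

    clip_weights = tf.constant([0.5, 1.0])                 # shape [num_microbatches]
    weight_batch = tf.constant([[1.0, 2.0], [3.0, 4.0]])   # shape [2, 2] after the reshape
    print(clip_weights[:, tf.newaxis] * weight_batch)      # -> [[0.5, 1.0], [3.0, 4.0]]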
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
# WARNING: When num_microbatches is not None, we need to be sure that # WARNING: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches # `compute_loss` always computes the mean over the microbatches
@ -274,17 +303,9 @@ def compute_clipped_gradients_and_outputs(
# is not defined in the contract so may not hold, especially for # is not defined in the contract so may not hold, especially for
# custom losses. # custom losses.
y_pred = input_model(x_batch, training=True) y_pred = input_model(x_batch, training=True)
loss_y_batch = ( mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
y_batch mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
if num_microbatches is None clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
else lr.add_microbatch_axis(y_batch, num_microbatches)
)
loss_y_pred = (
y_pred
if num_microbatches is None
else lr.add_microbatch_axis(y_pred, num_microbatches)
)
clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
clipped_grads = tape.gradient( clipped_grads = tape.gradient(
clipping_loss_value, clipping_loss_value,
input_model.trainable_variables, input_model.trainable_variables,

Changed file: tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py

@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
@ -71,17 +69,23 @@ def double_dense_layer_computation(
return [vars1, vars2], outputs, sqr_norm_fn return [vars1, vars2], outputs, sqr_norm_fn
def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: def test_loss_fn(
x = tf.reshape(x, (tf.shape(x)[0], -1)) x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None
y = tf.reshape(y, (tf.shape(y)[0], -1)) ) -> tf.Tensor:
# Define a loss function which is unlikely to be coincidentally defined. # Define a loss function which is unlikely to be coincidentally defined.
return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1) if weights is None:
weights = 1.0
loss = 3.14 * tf.reduce_sum(
tf.cast(weights, tf.float32) * tf.square(x - y), axis=1
)
return loss
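
A quick sanity check of the weighted test loss above (assuming `test_loss_fn` as defined here; the inputs are arbitrary): omitting `weights` behaves like all-ones weights, and a zero weight removes that example's contribution.

    import tensorflow as tf

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    y = tf.zeros_like(x)
    w = tf.constant([[1.0], [0.0]])
    print(test_loss_fn(x, y))     # -> 3.14 * [5.0, 25.0]
    print(test_loss_fn(x, y, w))  # -> 3.14 * [5.0, 0.0]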
def compute_true_gradient_norms( def compute_true_gradient_norms(
input_model: tf.keras.Model, input_model: tf.keras.Model,
x_batch: tf.Tensor, x_batch: tf.Tensor,
y_batch: tf.Tensor, y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor],
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int], num_microbatches: Optional[int],
trainable_vars: Optional[tf.Variable] = None, trainable_vars: Optional[tf.Variable] = None,
@ -93,7 +97,7 @@ def compute_true_gradient_norms(
per_example_loss_fn = input_model.loss.from_config(loss_config) per_example_loss_fn = input_model.loss.from_config(loss_config)
with tf.GradientTape(persistent=True) as tape: with tf.GradientTape(persistent=True) as tape:
y_pred = input_model(x_batch) y_pred = input_model(x_batch)
loss = per_example_loss_fn(y_batch, y_pred) loss = per_example_loss_fn(y_batch, y_pred, weight_batch)
if num_microbatches is not None: if num_microbatches is not None:
loss = tf.reduce_mean( loss = tf.reduce_mean(
tf.reshape( tf.reshape(
@ -123,7 +127,8 @@ def get_computed_and_true_norms(
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int], num_microbatches: Optional[int],
is_eager: bool, is_eager: bool,
x_input: tf.Tensor, x_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
rng_seed: int = 777, rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None, registry: layer_registry.LayerRegistry = None,
partial: bool = False, partial: bool = False,
@ -146,10 +151,11 @@ def get_computed_and_true_norms(
per_example_loss_fn: If not None, used as vectorized per example loss per_example_loss_fn: If not None, used as vectorized per example loss
function. function.
num_microbatches: The number of microbatches. None or an integer. num_microbatches: The number of microbatches. None or an integer.
is_eager: A `bool` that is `True` if the model should be run eagerly. is_eager: whether the model should be run eagerly.
x_input: `tf.Tensor` inputs to be tested. x_batch: inputs to be tested.
rng_seed: An `int` used to initialize model weights. weight_batch: optional weights passed to the loss.
registry: A `layer_registry.LayerRegistry` instance. rng_seed: used as a seed for random initialization.
registry: required for fast clipping.
partial: Whether to compute the gradient norm with respect to a partial set partial: Whether to compute the gradient norm with respect to a partial set
of variables. If True, only consider the variables in the first layer. of variables. If True, only consider the variables in the first layer.
@ -175,13 +181,14 @@ def get_computed_and_true_norms(
trainable_vars = l.trainable_variables trainable_vars = l.trainable_variables
if trainable_vars: if trainable_vars:
break break
y_pred = model(x_input) y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred) y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed) tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms( computed_norms = clip_grads.compute_gradient_norms(
model, input_model=model,
x_input, x_batch=x_batch,
y_batch, y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry, layer_registry=registry,
per_example_loss_fn=per_example_loss_fn, per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
@ -190,8 +197,9 @@ def get_computed_and_true_norms(
tf.keras.utils.set_random_seed(rng_seed) tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms( true_norms = compute_true_gradient_norms(
model, model,
x_input, x_batch,
y_batch, y_batch,
weight_batch,
per_example_loss_fn, per_example_loss_fn,
num_microbatches, num_microbatches,
trainable_vars=trainable_vars, trainable_vars=trainable_vars,
@ -309,24 +317,16 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
# ============================================================================== # ==============================================================================
# Factory functions. # Factory functions.
# ============================================================================== # ==============================================================================
def get_nd_test_tensors(n: int):
"""Returns a list of candidate tests for a given dimension n."""
return [
tf.zeros((n,), dtype=tf.float64),
tf.convert_to_tensor(range(n), dtype_hint=tf.float64),
]
def get_nd_test_batches(n: int): def get_nd_test_batches(n: int):
"""Returns a list of candidate input batches of dimension n.""" """Returns a list of input batches of dimension n."""
result = [] # The first two batches have a single element, the last batch has 2 elements.
tensors = get_nd_test_tensors(n) x0 = tf.zeros([1, n], dtype=tf.float64)
for batch_size in range(1, len(tensors) + 1, 1): x1 = tf.constant([range(n)], dtype=tf.float64)
combinations = list( x2 = tf.concat([x0, x1], axis=0)
itertools.combinations(get_nd_test_tensors(n), batch_size) w0 = tf.constant([1], dtype=tf.float64)
) w1 = tf.constant([2], dtype=tf.float64)
result = result + [tf.stack(ts, axis=0) for ts in combinations] w2 = tf.constant([0.5, 0.5], dtype=tf.float64)
return result return [x0, x1, x2], [w0, w1, w2]
def get_dense_layer_generators(): def get_dense_layer_generators():
@ -366,11 +366,14 @@ class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_clip_weights(self, input_dim, clip_value): def test_clip_weights(self, input_dim, clip_value):
tol = 1e-6 tol = 1e-6
for t in get_nd_test_tensors(input_dim): ts, _ = get_nd_test_batches(input_dim)
self.assertIsNone(clip_grads.compute_clip_weights(None, t)) for t in ts:
weights = clip_grads.compute_clip_weights(clip_value, t) weights = clip_grads.compute_clip_weights(clip_value, t)
self.assertAllLessEqual(t * weights, clip_value + tol) self.assertAllLessEqual(t * weights, clip_value + tol)
def test_clip_weights_none(self):
self.assertIsNone(clip_grads.compute_clip_weights(None, tf.ones(3)))
class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
@ -383,6 +386,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches=[None, 1, 2], num_microbatches=[None, 1, 2],
is_eager=[True, False], is_eager=[True, False],
partial=[True, False], partial=[True, False],
weighted=[True, False],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, self,
@ -394,21 +398,16 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches, num_microbatches,
is_eager, is_eager,
partial, partial,
weighted,
): ):
model_generator = get_dense_model_generators()[model_name] model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name] layer_generator = get_dense_layer_generators()[layer_name]
x_batches = get_nd_test_batches(input_dim) x_batches, weight_batches = get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry() default_registry = layer_registry.make_default_layer_registry()
for x_batch in x_batches: for x_batch, weight_batch in zip(x_batches, weight_batches):
if ( batch_size = x_batch.shape[0]
num_microbatches is not None if num_microbatches is not None and batch_size % num_microbatches != 0:
and x_batch.shape[0] % num_microbatches != 0
):
continue continue
if model_name == 'tower1':
x_input = [x_batch, x_batch]
else:
x_input = x_batch
(computed_norms, true_norms) = get_computed_and_true_norms( (computed_norms, true_norms) = get_computed_and_true_norms(
model_generator, model_generator,
layer_generator, layer_generator,
@ -417,10 +416,13 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn, per_example_loss_fn,
num_microbatches, num_microbatches,
is_eager, is_eager,
x_input, x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=default_registry, registry=default_registry,
partial=partial, partial=partial,
) )
expected_size = num_microbatches or batch_size
self.assertEqual(computed_norms.shape[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@ -471,16 +473,14 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
is_eager, is_eager,
partial, partial,
): ):
batch_size = x_batch.shape[0]
# The following are invalid test combinations, and are skipped.
if ( if (
num_microbatches is not None num_microbatches is not None and batch_size % num_microbatches != 0
and x_batch.shape[0] % num_microbatches != 0 ) or (
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
): ):
return return
valid_test_input = (
not isinstance(x_batch, tf.RaggedTensor)
and model_name == 'weighted_bow1'
) or (model_name != 'weighted_bow1')
if valid_test_input:
default_registry = layer_registry.make_default_layer_registry() default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name] model_generator = get_embedding_model_generators()[model_name]
(computed_norms, true_norms) = get_computed_and_true_norms( (computed_norms, true_norms) = get_computed_and_true_norms(
@ -491,10 +491,11 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn=per_example_loss_fn, per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
is_eager=is_eager, is_eager=is_eager,
x_input=x_batch, x_batch=x_batch,
registry=default_registry, registry=default_registry,
partial=partial, partial=partial,
) )
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@ -507,6 +508,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches=[None, 2], num_microbatches=[None, 2],
is_eager=[True, False], is_eager=[True, False],
partial=[True, False], partial=[True, False],
weighted=[True, False],
) )
def test_gradient_norms_on_various_models( def test_gradient_norms_on_various_models(
self, self,
@ -516,15 +518,14 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches, num_microbatches,
is_eager, is_eager,
partial, partial,
weighted,
): ):
registry = layer_registry.make_default_layer_registry() registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation) registry.insert(DoubleDense, double_dense_layer_computation)
x_batches = get_nd_test_batches(input_dim) x_batches, weight_batches = get_nd_test_batches(input_dim)
for x_batch in x_batches: for x_batch, weight_batch in zip(x_batches, weight_batches):
if ( batch_size = x_batch.shape[0]
num_microbatches is not None if num_microbatches is not None and batch_size % num_microbatches != 0:
and x_batch.shape[0] % num_microbatches != 0
):
continue continue
(computed_norms, true_norms) = get_computed_and_true_norms( (computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=make_two_layer_sequential_model, model_generator=make_two_layer_sequential_model,
@ -534,10 +535,12 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn=per_example_loss_fn, per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
is_eager=is_eager, is_eager=is_eager,
x_input=x_batch, x_batch=x_batch,
weight_batch=weight_batch if weighted else None,
registry=registry, registry=registry,
partial=partial, partial=partial,
) )
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@ -574,16 +577,13 @@ class ClipGradsComputeClippedGradsAndOutputsTest(
) )
# Stop early for efficiency. # Stop early for efficiency.
if reduction == 'none': if reduction == 'none':
self.assertRaises( with self.assertRaises(NotImplementedError):
NotImplementedError, clip_grads.compute_clipped_gradients_and_outputs(
# function tested
clip_grads.compute_clipped_gradients_and_outputs,
# function args
self._model, self._model,
x_batch,
y_batch,
l2_norm_clip, l2_norm_clip,
layer_registry.make_default_layer_registry(), layer_registry.make_default_layer_registry(),
x_batch,
y_batch,
) )
return return
# NOTE: losses from this point are scalar losses. # NOTE: losses from this point are scalar losses.
@ -593,10 +593,10 @@ class ClipGradsComputeClippedGradsAndOutputsTest(
true_grads = tape.gradient(loss_value, self._model.trainable_variables) true_grads = tape.gradient(loss_value, self._model.trainable_variables)
clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs( clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
self._model, self._model,
x_batch,
y_batch,
l2_norm_clip, l2_norm_clip,
layer_registry.make_default_layer_registry(), layer_registry.make_default_layer_registry(),
x_batch,
y_batch,
) )
# Computes the L2 norm manually. # Computes the L2 norm manually.

Changed file: tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py

@ -105,26 +105,23 @@ class LayerRegistry:
# ============================================================================== # ==============================================================================
# Utilities # Utilities
# ============================================================================== # ==============================================================================
def add_microbatch_axis( def maybe_add_microbatch_axis(
x: tf.Tensor, x: tf.Tensor,
num_microbatches: Optional[BatchSize], num_microbatches: Optional[BatchSize],
) -> tf.Tensor: ) -> tf.Tensor:
"""Adds the microbatch axis. """Adds the microbatch axis.
Reshape the input tensor to replace the first(batch) dimension with the
shape [num_microbatches, batch_size / num_microbatches]. The batch size
must be a multiple of num_microbatches (unless it is None, meaning
num_microbatches is the same as the batch size).
Args: Args:
x: the input tensor. x: the input tensor.
num_microbatches: None or a numeric value or a scalar `tf.Tensor`. num_microbatches: If None, x is returned unchanged. Otherwise, must divide
the batch size.
Returns: Returns:
The reshaped input tensor. The input tensor x, returned unchanged if num_microbatches is None;
otherwise x reshaped from [batch_size, ...] to [num_microbatches, batch_size / num_microbatches, ...].
""" """
if num_microbatches is None: if num_microbatches is None:
return tf.expand_dims(x, 1) return x
with tf.control_dependencies( with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)] [tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
): ):
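
A behavior sketch of the renamed helper (the values are arbitrary; the shapes are the point):

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

    x = tf.reshape(tf.range(12, dtype=tf.float32), [6, 2])
    print(lr.maybe_add_microbatch_axis(x, None).shape)  # (6, 2), returned unchanged
    print(lr.maybe_add_microbatch_axis(x, 3).shape)     # (3, 2, 2)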
@ -193,7 +190,7 @@ def dense_layer_computation(
def _compute_gramian(x): def _compute_gramian(x):
if num_microbatches is not None: if num_microbatches is not None:
x_microbatched = add_microbatch_axis(x, num_microbatches) x_microbatched = maybe_add_microbatch_axis(x, num_microbatches)
return tf.matmul(x_microbatched, x_microbatched, transpose_b=True) return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
else: else:
# Special handling for better efficiency # Special handling for better efficiency

Changed file: tensorflow_privacy/privacy/keras_models/dp_keras_model.py

@ -15,7 +15,6 @@
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
@ -179,15 +178,19 @@ def make_dp_model_class(cls):
tf.shape(stacked_grads)[0], summed_grads.dtype tf.shape(stacked_grads)[0], summed_grads.dtype
) )
def _compute_per_example_grads(self, data): def _compute_per_example_grads(self, microbatched_data):
if self._clipping_loss is None: if self._clipping_loss is None:
self._make_clipping_loss() self._make_clipping_loss()
microbatched_x, microbatched_y = data microbatched_x, microbatched_y, microbatched_weights = (
tf.keras.utils.unpack_x_y_sample_weight(microbatched_data)
)
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
microbatched_y_pred = self(microbatched_x, training=True) microbatched_y_pred = self(microbatched_x, training=True)
# NOTE: `self._clipping_loss` does not include any regularization terms. # NOTE: `self._clipping_loss` does not include any regularization terms.
microbatched_loss = self._clipping_loss( microbatched_loss = self._clipping_loss(
microbatched_y, microbatched_y_pred microbatched_y,
microbatched_y_pred,
sample_weight=microbatched_weights,
) )
grads_list = tape.gradient(microbatched_loss, self.trainable_variables) grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
clipped_grads = self._process_per_example_grads(grads_list) clipped_grads = self._process_per_example_grads(grads_list)
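
The unpacking above relies on a standard Keras utility that accepts either `(x, y)` or `(x, y, sample_weight)` tuples; a standalone sketch:

    import tensorflow as tf

    data2 = (tf.ones([4, 2]), tf.ones([4, 1]))
    data3 = (tf.ones([4, 2]), tf.ones([4, 1]), tf.constant([0.1, 0.2, 0.3, 0.4]))
    print(tf.keras.utils.unpack_x_y_sample_weight(data2)[2])  # None
    print(tf.keras.utils.unpack_x_y_sample_weight(data3)[2])  # the weight vector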
@ -232,12 +235,8 @@ def make_dp_model_class(cls):
self._make_clipping_loss() self._make_clipping_loss()
output_metrics = {} output_metrics = {}
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data) x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
if weights is not None:
raise NotImplementedError(
'DPModel does not currently support weighted losses.'
)
batch_size = tf.shape(y)[0] batch_size = tf.shape(y)[0]
eff_num_microbatches = self._num_microbatches or batch_size num_microbatches = self._num_microbatches or batch_size
# Branch based on gradient clipping algorithm. # Branch based on gradient clipping algorithm.
if self._enable_fast_peg_computation: if self._enable_fast_peg_computation:
@ -251,13 +250,14 @@ def make_dp_model_class(cls):
# microbatches is done here. # microbatches is done here.
clipped_grads, y_pred, clipping_loss = ( clipped_grads, y_pred, clipping_loss = (
clip_grads.compute_clipped_gradients_and_outputs( clip_grads.compute_clipped_gradients_and_outputs(
self, input_model=self,
x, x_batch=x,
y, y_batch=y,
self._l2_norm_clip, weight_batch=weights,
self._layer_registry, l2_norm_clip=self._l2_norm_clip,
self._num_microbatches, layer_registry=self._layer_registry,
self._clipping_loss, num_microbatches=self._num_microbatches,
clipping_loss=self._clipping_loss,
) )
) )
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
@ -265,7 +265,7 @@ def make_dp_model_class(cls):
grads = gradient_clipping_utils.add_aggregate_noise( grads = gradient_clipping_utils.add_aggregate_noise(
self, self,
clipped_grads, clipped_grads,
eff_num_microbatches, num_microbatches,
self._l2_norm_clip, self._l2_norm_clip,
self._noise_multiplier, self._noise_multiplier,
) )
@ -276,7 +276,7 @@ def make_dp_model_class(cls):
# Computes per-example clipped gradients directly. This is called # Computes per-example clipped gradients directly. This is called
# if at least one of the layers cannot use the "fast" gradient clipping # if at least one of the layers cannot use the "fast" gradient clipping
# algorithm. # algorithm.
reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches) reshape_fn = lambda z: lr.maybe_add_microbatch_axis(z, num_microbatches)
microbatched_data = tf.nest.map_structure(reshape_fn, data) microbatched_data = tf.nest.map_structure(reshape_fn, data)
clipped_grads = tf.vectorized_map( clipped_grads = tf.vectorized_map(
self._compute_per_example_grads, self._compute_per_example_grads,
@ -305,7 +305,9 @@ def make_dp_model_class(cls):
output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss
# Log the true loss, including regularization losses. # Log the true loss, including regularization losses.
self.compiled_loss(y, y_pred, regularization_losses=self.losses) self.compiled_loss(
y, y_pred, sample_weight=weights, regularization_losses=self.losses
)
# Forward the private gradients to the optimizer and return the results. # Forward the private gradients to the optimizer and return the results.
self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
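
Putting the pieces together, training with per-example weights now works by passing an `(x, y, sample_weight)` dataset to `fit`, as exercised in the updated tests below. A condensed sketch (toy data and hyperparameters are illustrative):

    import numpy as np
    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
    from tensorflow_privacy.privacy.keras_models import dp_keras_model

    x = np.array([[3.0, 4.0], [1.0, 2.0]])
    y = np.array([[15.0], [5.0]])
    w = np.array([0.2, 1.0])  # per-example sample weights

    model = dp_keras_model.DPSequential(
        l2_norm_clip=1.0,
        noise_multiplier=0.0,
        layer_registry=layer_registry.make_default_layer_registry(),
        layers=[
            tf.keras.layers.InputLayer(input_shape=(2,)),
            tf.keras.layers.Dense(1, kernel_initializer='zeros', bias_initializer='zeros'),
        ],
    )
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
        loss=tf.keras.losses.MeanSquaredError(),
    )
    model.fit(tf.data.Dataset.from_tensors((x, y, w)), epochs=1)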

Changed file: tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py

@ -21,8 +21,11 @@ from tensorflow_privacy.privacy.keras_models import dp_keras_model
def get_data(): def get_data():
# Data is for hidden weights of [3, 1] and bias of 2. # Data is for hidden weights of [3, 1] and bias of 2.
# With mean squared loss, we expect loss = 15^2 = 225, gradients of # Loss is (w.x + b - y)^2, model is initialized at (w, b) = (0, 0).
# weights = [90, 120], and gradient of bias = 30. # y = 15
# Loss: y^2 = 15^2 = 225
# Gradient w.r.t. w = -2yx = [-90, -120] (magnitude [90, 120])
# Gradient w.r.t. b = -2y = -30 (magnitude 30)
data = np.array([[3, 4]]) data = np.array([[3, 4]])
labels = np.matmul(data, [[3], [1]]) + 2 labels = np.matmul(data, [[3], [1]]) + 2
return data, labels return data, labels
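
A worked check of the numbers in the comment above (w and b start at zero, so the prediction is 0):

    import numpy as np

    x = np.array([3.0, 4.0])
    y = 3 * 3 + 1 * 4 + 2        # label = 15
    pred = 0.0
    print((pred - y) ** 2)        # loss = 225
    print(2 * (pred - y) * x)     # gradient w.r.t. w = [-90, -120]
    print(2 * (pred - y))         # gradient w.r.t. b = -30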
@ -41,8 +44,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
layers=[ layers=[
tf.keras.layers.InputLayer(input_shape=(2,)), tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense( tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros') 1, kernel_initializer='zeros', bias_initializer='zeros'
]) ),
],
)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01) optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss = tf.keras.losses.MeanSquaredError() loss = tf.keras.losses.MeanSquaredError()
@ -58,101 +63,149 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product( @parameterized.product(
l2_norm_clip=(10.0, 40.0, 200.0), l2_norm_clip=(10.0, 40.0, 200.0),
fast_clipping=(True, False), fast_clipping=(True, False),
sequential=(True, False),
weighted=(False, True),
) )
def testClippingNorm(self, l2_norm_clip, fast_clipping): def testClippingNorm(self, l2_norm_clip, fast_clipping, sequential, weighted):
"""Tests that clipping norm works.""" """Tests that clipping norm works."""
train_data, train_labels = get_data() train_data, train_labels = get_data()
# Simple linear model returns w * x + b. # Simple linear model returns w * x + b.
layer = tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
)
if sequential:
model = dp_keras_model.DPSequential( model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip, l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0, noise_multiplier=0.0,
layer_registry=layer_registry.make_default_layer_registry() layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping if fast_clipping
else None, else None,
layers=[ layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer],
tf.keras.layers.InputLayer(input_shape=(2,)), )
tf.keras.layers.Dense( else:
1, kernel_initializer='zeros', bias_initializer='zeros' inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)
), outputs = layer(inputs)
], model = dp_keras_model.DPModel(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
inputs=inputs,
outputs=outputs,
) )
learning_rate = 0.01 learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError() loss = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer, loss=loss) model.compile(optimizer=optimizer, loss=loss)
expected_loss = loss(train_labels, model(train_data))
results = model.fit(train_data, train_labels, epochs=1, batch_size=1)
model_weights = model.get_weights() weights = None
data = tf.data.Dataset.from_tensors((train_data, train_labels))
expected_grad_w = np.array([90.0, 120.0])
expected_grad_b = np.array([30.0])
if weighted:
# Apply a weight to the (single) example.
weights = [0.18]
data = tf.data.Dataset.from_tensors((train_data, train_labels, weights))
expected_grad_w *= 0.18
expected_grad_b *= 0.18
unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2) unclipped_norm = np.linalg.norm(
scale = min(1.0, l2_norm_clip / unclipped_gradient) np.concatenate([expected_grad_w, expected_grad_b])
expected_weights = np.array([[90], [120]]) * scale * learning_rate )
expected_bias = np.array([30]) * scale * learning_rate scale = min(1.0, l2_norm_clip / unclipped_norm)
expected_weights = expected_grad_w * scale * learning_rate
expected_bias = expected_grad_b * scale * learning_rate
expected_loss = loss(train_labels, model(train_data), weights)
results = model.fit(data, epochs=1, batch_size=1)
weights, bias = model.get_weights()
# Check parameters are as expected, taking into account the learning rate. # Check parameters are as expected, taking into account the learning rate.
self.assertAllClose(model_weights[0], expected_weights) self.assertAllClose(np.squeeze(weights), expected_weights)
self.assertAllClose(model_weights[1], expected_bias) self.assertAllClose(bias, expected_bias)
# Check the value of the loss. # Check the value of the loss.
actual_loss = results.history['loss'][0] actual_loss = results.history['loss'][0]
self.assertAllClose(expected_loss, actual_loss) self.assertAllClose(expected_loss, actual_loss)
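
For the weighted case above, the clip scale is computed from the weighted gradient. A numeric check with weight 0.18 (as in the test) and l2_norm_clip = 10.0, one of the parameterized values:

    import numpy as np

    grad = np.array([90.0, 120.0, 30.0]) * 0.18   # weighted unclipped gradient magnitudes
    norm = np.linalg.norm(grad)                   # ~27.53
    scale = min(1.0, 10.0 / norm)                 # ~0.363
    print(np.round(grad * scale, 3))              # clipped gradient, norm == 10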
def _compute_expected_gradients(self, data, labels, w, l2_norm_clip, def _compute_expected_gradients(
num_microbatches): self,
data,
labels,
weights,
w0,
l2_norm_clip,
num_microbatches,
):
if weights is None:
weights = np.array([1], dtype=np.float32)
batch_size = data.shape[0] batch_size = data.shape[0]
if num_microbatches is None: if num_microbatches is None:
num_microbatches = batch_size num_microbatches = batch_size
preds = np.matmul(data, w0[:, np.newaxis])
preds = np.matmul(data, np.expand_dims(w, axis=1)) grads = 2 * data * (preds - labels) * weights[:, np.newaxis]
grads = np.reshape(
grads = 2 * data * (preds - labels) grads, [num_microbatches, batch_size // num_microbatches, -1]
)
grads = np.reshape(grads,
[num_microbatches, batch_size // num_microbatches, -1])
mb_grads = np.mean(grads, axis=1) mb_grads = np.mean(grads, axis=1)
mb_grad_norms = np.linalg.norm(mb_grads, axis=1) mb_grad_norms = np.linalg.norm(mb_grads, axis=1)
scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0) scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0)
mb_grads = mb_grads * scale[:, np.newaxis] mb_grads = mb_grads * scale[:, np.newaxis]
final_grads = np.mean(mb_grads, axis=0) final_grads = np.mean(mb_grads, axis=0)
return final_grads return final_grads
@parameterized.product( @parameterized.product(
num_microbatches=(None, 1, 2, 4), num_microbatches=(None, 1, 2, 4),
fast_clipping=(False, True), fast_clipping=(True, False),
sequential=(False, True),
weighted=(True, False),
) )
def testMicrobatches(self, num_microbatches, fast_clipping): def testMicrobatches(
self, num_microbatches, fast_clipping, sequential, weighted
):
l2_norm_clip = 1.0 l2_norm_clip = 1.0
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]) train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
w = np.zeros((2))
train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]]) train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
if weighted:
train_weights = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
dataset = tf.data.Dataset.from_tensors(
(train_data, train_labels, train_weights)
)
else:
train_weights = None
dataset = tf.data.Dataset.from_tensors((train_data, train_labels))
learning_rate = 1.0 learning_rate = 1.0
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError() loss = tf.keras.losses.MeanSquaredError()
# Simple linear model returns w * x. # Simple linear model returns w * x.
layer = tf.keras.layers.Dense(1, use_bias=False, kernel_initializer='zeros')
registry = (
layer_registry.make_default_layer_registry() if fast_clipping else None
)
if sequential:
model = dp_keras_model.DPSequential( model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip, l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
layer_registry=layer_registry.make_default_layer_registry() layer_registry=registry,
if fast_clipping layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer],
else None, )
layers=[ else:
tf.keras.layers.InputLayer(input_shape=(2,)), inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)
tf.keras.layers.Dense( outputs = layer(inputs)
1, use_bias=False, kernel_initializer='zeros' model = dp_keras_model.DPModel(
), l2_norm_clip=l2_norm_clip,
], noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=registry,
inputs=inputs,
outputs=outputs,
) )
model.compile(optimizer=optimizer, loss=loss) model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False) model.fit(dataset, epochs=1, batch_size=4, shuffle=False)
model_weights = np.squeeze(model.get_weights()) model_weights = np.squeeze(model.get_weights())
@ -163,7 +216,14 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
) )
expected_grads = self._compute_expected_gradients( expected_grads = self._compute_expected_gradients(
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches train_data,
train_labels,
train_weights,
np.zeros(
2,
),
l2_norm_clip,
effective_num_microbatches,
) )
expected_weights = np.squeeze(-learning_rate * expected_grads) expected_weights = np.squeeze(-learning_rate * expected_grads)
self.assertAllClose(model_weights, expected_weights) self.assertAllClose(model_weights, expected_weights)