Support weighted losses in DPModel.

PiperOrigin-RevId: 538011437
This commit is contained in:
Walid Krichene 2023-06-05 16:26:38 -07:00 committed by A. Unique TensorFlower
parent 60d237be83
commit 18c43b351b
5 changed files with 297 additions and 217 deletions

View file

@ -69,9 +69,10 @@ def get_registry_generator_fn(
def compute_gradient_norms(
input_model: tf.keras.Model,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry,
weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[LossFn] = None,
num_microbatches: Optional[lr.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None,
@ -83,16 +84,19 @@ def compute_gradient_norms(
Args:
input_model: The `tf.keras.Model` from which to obtain the layers from. The
loss of the model *must* be a scalar loss.
loss of the model *must* be a scalar loss. When using microbatching, the
loss reduction must be mean.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
weight_batch: Optional batch of weights, passed to the loss function.
Weights apply to the loss prior to clipping.
per_example_loss_fn: takes as input predictions, labels and weights, and
outputs a vector of per-example losses. If None, derived from
`input_model.loss` by disabling its reduction.
@ -108,8 +112,11 @@ def compute_gradient_norms(
variables are included.
Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
per-example loss function.
A scalar vector, whose i-th entry is the norm of the gradient of the i-th
weighted example loss (when num_microbatches is None) or the norm of the
gradient of the i-th microbatch loss (define as a mean over the microbatch).
Note that when the loss is weighted (`weight_batch` is not None), weights
are applied prior to clipping.
"""
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn(
@ -127,7 +134,7 @@ def compute_gradient_norms(
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs)
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None:
raise NotImplementedError(
"The unreduced (or per-example) loss's shape cannot be `None`"
@ -140,7 +147,7 @@ def compute_gradient_norms(
)
if num_microbatches is not None:
losses = tf.reduce_mean(
lr.add_microbatch_axis(losses, num_microbatches), axis=1
lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
@ -165,6 +172,10 @@ def compute_gradient_norms(
vars_list,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
if not grads_list:
raise ValueError('The gradient list cannot be empty.')
if len(grads_list) != len(sqr_norm_fns_list):
raise ValueError('There must be as many norms as gradients.')
sqr_norm_list = []
for grads, f in zip(grads_list, sqr_norm_fns_list):
sqr_norm_list.append(f(grads))
@ -199,30 +210,26 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model,
x_batch: InputTensor,
y_batch: tf.Tensor,
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[lr.BatchSize] = None,
clipping_loss: Optional[LossFn] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
"""Computes the per-example clipped loss gradient and other useful outputs.
Given a batch of observations `(x_batch, y_batch)`, the main steps of this
function are: (i) compute the l2-norm of the gradients of the trainable
variables of `input_model` for each example in the batch; (ii) use the norms
computed in (i) to obtain "clip_weights" that are used to generate a weighted
loss function whose gradient for each example has l2-norm at most
`l2_norm_clip`; (iii) output the clipped gradients in (ii) and other useful
outputs to the caller.
Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
steps of this function are:
(i) compute the l2-norm of the gradients w.r.t. the trainable variables of
`input_model`, for each weighted example loss in the batch;
(ii) use the norms computed in (i) to obtain "clip_weights" that are used to
reweight the loss function, such that each gradient of this reweighted loss
has l2-norm at most `l2_norm_clip`.
Args:
input_model: The `tf.keras.Model` from which to obtain the layers from.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
l2_norm_clip: A `float` indicating the norm to which per-example gradients
will be clipped. That is, all gradients of the per-example loss functions
will have norm at most `l2_norm_clip`.
@ -232,6 +239,15 @@ def compute_clipped_gradients_and_outputs(
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
weight_batch: Optional vector of weights, passed to the loss function. Must
be of size [batch_size]. In case of microbatching, this will be reshaped
to [num_microbatches, batch_size/num_microbatches] before passing it to
the loss.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
@ -243,11 +259,11 @@ def compute_clipped_gradients_and_outputs(
the value of the clipped loss does not reflect the true loss.
Returns:
A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
clipped gradient of the loss function, the second is the result of
applying `input_model` to `x_batch`, and the third is loss value of
`input_model`, weighted by the loss weights generated by a specific
`compute_clip_weights()` call.
clipped_grad: list of the clipped gradients of the loss function (one per
trainable variable in `input_model`).
y_pred: the result of applying `input_model` to `x_batch`.
clipping_loss_value: the loss value weighted in such a way that its gradient
is `clipped_grad`.
"""
if input_model.loss.reduction == 'none':
raise NotImplementedError(
@ -258,13 +274,26 @@ def compute_clipped_gradients_and_outputs(
clipping_loss = input_model.compiled_loss
gradient_norms = compute_gradient_norms(
input_model,
layer_registry,
x_batch,
y_batch,
layer_registry,
weight_batch,
num_microbatches=num_microbatches,
trainable_vars=input_model.trainable_variables,
)
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
if weight_batch is not None:
# Let w be the `weight_batch`, c be the `clip_weights`, and l be the losses.
# c is computed based on the gradient of w*l, so that if we scale w*l by c,
# the result has bounded per-example gradients. So the loss to optimize is
# c*w*l. Here we compute c*w before passing it to the loss.
weight_batch = lr.maybe_add_microbatch_axis(weight_batch, num_microbatches)
if num_microbatches is None:
clip_weights = clip_weights * weight_batch # shape [num_microbatches]
else:
# In this case, weight_batch is of shape [batch_size, microbatch_size],
# we multiply by the clip_weights (which is of shape [num_microbatches])
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
with tf.GradientTape() as tape:
# WARNING: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches
@ -274,17 +303,9 @@ def compute_clipped_gradients_and_outputs(
# is not defined in the contract so may not hold, especially for
# custom losses.
y_pred = input_model(x_batch, training=True)
loss_y_batch = (
y_batch
if num_microbatches is None
else lr.add_microbatch_axis(y_batch, num_microbatches)
)
loss_y_pred = (
y_pred
if num_microbatches is None
else lr.add_microbatch_axis(y_pred, num_microbatches)
)
clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
clipped_grads = tape.gradient(
clipping_loss_value,
input_model.trainable_variables,

View file

@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
@ -71,17 +69,23 @@ def double_dense_layer_computation(
return [vars1, vars2], outputs, sqr_norm_fn
def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
x = tf.reshape(x, (tf.shape(x)[0], -1))
y = tf.reshape(y, (tf.shape(y)[0], -1))
def test_loss_fn(
x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None
) -> tf.Tensor:
# Define a loss function which is unlikely to be coincidently defined.
return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1)
if weights is None:
weights = 1.0
loss = 3.14 * tf.reduce_sum(
tf.cast(weights, tf.float32) * tf.square(x - y), axis=1
)
return loss
def compute_true_gradient_norms(
input_model: tf.keras.Model,
x_batch: tf.Tensor,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor],
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
trainable_vars: Optional[tf.Variable] = None,
@ -93,7 +97,7 @@ def compute_true_gradient_norms(
per_example_loss_fn = input_model.loss.from_config(loss_config)
with tf.GradientTape(persistent=True) as tape:
y_pred = input_model(x_batch)
loss = per_example_loss_fn(y_batch, y_pred)
loss = per_example_loss_fn(y_batch, y_pred, weight_batch)
if num_microbatches is not None:
loss = tf.reduce_mean(
tf.reshape(
@ -123,7 +127,8 @@ def get_computed_and_true_norms(
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
is_eager: bool,
x_input: tf.Tensor,
x_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
partial: bool = False,
@ -146,10 +151,11 @@ def get_computed_and_true_norms(
per_example_loss_fn: If not None, used as vectorized per example loss
function.
num_microbatches: The number of microbatches. None or an integer.
is_eager: A `bool` that is `True` if the model should be run eagerly.
x_input: `tf.Tensor` inputs to be tested.
rng_seed: An `int` used to initialize model weights.
registry: A `layer_registry.LayerRegistry` instance.
is_eager: whether the model should be run eagerly.
x_batch: inputs to be tested.
weight_batch: optional weights passed to the loss.
rng_seed: used as a seed for random initialization.
registry: required for fast clipping.
partial: Whether to compute the gradient norm with respect to a partial set
of varibles. If True, only consider the variables in the first layer.
@ -175,13 +181,14 @@ def get_computed_and_true_norms(
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_input)
y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
model,
x_input,
y_batch,
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
@ -190,8 +197,9 @@ def get_computed_and_true_norms(
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
model,
x_input,
x_batch,
y_batch,
weight_batch,
per_example_loss_fn,
num_microbatches,
trainable_vars=trainable_vars,
@ -309,24 +317,16 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
# ==============================================================================
# Factory functions.
# ==============================================================================
def get_nd_test_tensors(n: int):
"""Returns a list of candidate tests for a given dimension n."""
return [
tf.zeros((n,), dtype=tf.float64),
tf.convert_to_tensor(range(n), dtype_hint=tf.float64),
]
def get_nd_test_batches(n: int):
"""Returns a list of candidate input batches of dimension n."""
result = []
tensors = get_nd_test_tensors(n)
for batch_size in range(1, len(tensors) + 1, 1):
combinations = list(
itertools.combinations(get_nd_test_tensors(n), batch_size)
)
result = result + [tf.stack(ts, axis=0) for ts in combinations]
return result
"""Returns a list of input batches of dimension n."""
# The first two batches have a single element, the last batch has 2 elements.
x0 = tf.zeros([1, n], dtype=tf.float64)
x1 = tf.constant([range(n)], dtype=tf.float64)
x2 = tf.concat([x0, x1], axis=0)
w0 = tf.constant([1], dtype=tf.float64)
w1 = tf.constant([2], dtype=tf.float64)
w2 = tf.constant([0.5, 0.5], dtype=tf.float64)
return [x0, x1, x2], [w0, w1, w2]
def get_dense_layer_generators():
@ -366,11 +366,14 @@ class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase):
)
def test_clip_weights(self, input_dim, clip_value):
tol = 1e-6
for t in get_nd_test_tensors(input_dim):
self.assertIsNone(clip_grads.compute_clip_weights(None, t))
ts, _ = get_nd_test_batches(input_dim)
for t in ts:
weights = clip_grads.compute_clip_weights(clip_value, t)
self.assertAllLessEqual(t * weights, clip_value + tol)
def test_clip_weights_none(self):
self.assertIsNone(clip_grads.compute_clip_weights(None, tf.ones(3)))
class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
@ -383,6 +386,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches=[None, 1, 2],
is_eager=[True, False],
partial=[True, False],
weighted=[True, False],
)
def test_gradient_norms_on_various_models(
self,
@ -394,21 +398,16 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches,
is_eager,
partial,
weighted,
):
model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name]
x_batches = get_nd_test_batches(input_dim)
x_batches, weight_batches = get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry()
for x_batch in x_batches:
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0]
if num_microbatches is not None and batch_size % num_microbatches != 0:
continue
if model_name == 'tower1':
x_input = [x_batch, x_batch]
else:
x_input = x_batch
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator,
layer_generator,
@ -417,10 +416,13 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn,
num_microbatches,
is_eager,
x_input,
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=default_registry,
partial=partial,
)
expected_size = num_microbatches or batch_size
self.assertEqual(computed_norms.shape[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@ -471,31 +473,30 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
is_eager,
partial,
):
batch_size = x_batch.shape[0]
# The following are invalid test combinations, and are skipped.
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
num_microbatches is not None and batch_size % num_microbatches != 0
) or (
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
):
return
valid_test_input = (
not isinstance(x_batch, tf.RaggedTensor)
and model_name == 'weighted_bow1'
) or (model_name != 'weighted_bow1')
if valid_test_input:
default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name]
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=model_generator,
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_input=x_batch,
registry=default_registry,
partial=partial,
)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name]
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=model_generator,
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_batch=x_batch,
registry=default_registry,
partial=partial,
)
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
@ -507,6 +508,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches=[None, 2],
is_eager=[True, False],
partial=[True, False],
weighted=[True, False],
)
def test_gradient_norms_on_various_models(
self,
@ -516,15 +518,14 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches,
is_eager,
partial,
weighted,
):
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)
x_batches = get_nd_test_batches(input_dim)
for x_batch in x_batches:
if (
num_microbatches is not None
and x_batch.shape[0] % num_microbatches != 0
):
x_batches, weight_batches = get_nd_test_batches(input_dim)
for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0]
if num_microbatches is not None and batch_size % num_microbatches != 0:
continue
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=make_two_layer_sequential_model,
@ -534,10 +535,12 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_input=x_batch,
x_batch=x_batch,
weight_batch=weight_batch if weighted else None,
registry=registry,
partial=partial,
)
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
@ -574,17 +577,14 @@ class ClipGradsComputeClippedGradsAndOutputsTest(
)
# Stop early for efficiency.
if reduction == 'none':
self.assertRaises(
NotImplementedError,
# function tested
clip_grads.compute_clipped_gradients_and_outputs,
# function args
self._model,
x_batch,
y_batch,
l2_norm_clip,
layer_registry.make_default_layer_registry(),
)
with self.assertRaises(NotImplementedError):
clip_grads.compute_clipped_gradients_and_outputs(
self._model,
l2_norm_clip,
layer_registry.make_default_layer_registry(),
x_batch,
y_batch,
)
return
# NOTE: losses from this point are scalar losses.
with tf.GradientTape() as tape:
@ -593,10 +593,10 @@ class ClipGradsComputeClippedGradsAndOutputsTest(
true_grads = tape.gradient(loss_value, self._model.trainable_variables)
clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
self._model,
x_batch,
y_batch,
l2_norm_clip,
layer_registry.make_default_layer_registry(),
x_batch,
y_batch,
)
# Computes the L2 norm manually.

View file

@ -105,26 +105,23 @@ class LayerRegistry:
# ==============================================================================
# Utilities
# ==============================================================================
def add_microbatch_axis(
def maybe_add_microbatch_axis(
x: tf.Tensor,
num_microbatches: Optional[BatchSize],
) -> tf.Tensor:
"""Adds the microbatch axis.
Reshape the input tensor to replace the first(batch) dimension with the
shape [num_microbatches, batch_size / num_microbatches]. The batch size
must be a multiple of num_microbatches (unless it is None, meaning
num_microbatches is the same as the batch size).
Args:
x: the input tensor.
num_microbatches: None or a numeric value or a scalar `tf.Tensor`.
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
the batch size.
Returns:
The reshaped input tensor.
The input tensor x, reshaped from [batch_size, ...] to
[num_microbatches, batch_size / num_microbatches, ...].
"""
if num_microbatches is None:
return tf.expand_dims(x, 1)
return x
with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
):
@ -193,7 +190,7 @@ def dense_layer_computation(
def _compute_gramian(x):
if num_microbatches is not None:
x_microbatched = add_microbatch_axis(x, num_microbatches)
x_microbatched = maybe_add_microbatch_axis(x, num_microbatches)
return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
else:
# Special handling for better efficiency

View file

@ -15,7 +15,6 @@
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
@ -179,15 +178,19 @@ def make_dp_model_class(cls):
tf.shape(stacked_grads)[0], summed_grads.dtype
)
def _compute_per_example_grads(self, data):
def _compute_per_example_grads(self, microbatched_data):
if self._clipping_loss is None:
self._make_clipping_loss()
microbatched_x, microbatched_y = data
microbatched_x, microbatched_y, microbatched_weights = (
tf.keras.utils.unpack_x_y_sample_weight(microbatched_data)
)
with tf.GradientTape() as tape:
microbatched_y_pred = self(microbatched_x, training=True)
# NOTE: `self._clipping_loss` does not include any regularization terms.
microbatched_loss = self._clipping_loss(
microbatched_y, microbatched_y_pred
microbatched_y,
microbatched_y_pred,
sample_weight=microbatched_weights,
)
grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
clipped_grads = self._process_per_example_grads(grads_list)
@ -232,12 +235,8 @@ def make_dp_model_class(cls):
self._make_clipping_loss()
output_metrics = {}
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
if weights is not None:
raise NotImplementedError(
'DPModel does not currently support weighted losses.'
)
batch_size = tf.shape(y)[0]
eff_num_microbatches = self._num_microbatches or batch_size
num_microbatches = self._num_microbatches or batch_size
# Branch based on gradient clipping algorithm.
if self._enable_fast_peg_computation:
@ -251,13 +250,14 @@ def make_dp_model_class(cls):
# microbatches is done here.
clipped_grads, y_pred, clipping_loss = (
clip_grads.compute_clipped_gradients_and_outputs(
self,
x,
y,
self._l2_norm_clip,
self._layer_registry,
self._num_microbatches,
self._clipping_loss,
input_model=self,
x_batch=x,
y_batch=y,
weight_batch=weights,
l2_norm_clip=self._l2_norm_clip,
layer_registry=self._layer_registry,
num_microbatches=self._num_microbatches,
clipping_loss=self._clipping_loss,
)
)
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
@ -265,7 +265,7 @@ def make_dp_model_class(cls):
grads = gradient_clipping_utils.add_aggregate_noise(
self,
clipped_grads,
eff_num_microbatches,
num_microbatches,
self._l2_norm_clip,
self._noise_multiplier,
)
@ -276,7 +276,7 @@ def make_dp_model_class(cls):
# Computes per-example clipped gradients directly. This is called
# if at least one of the layers cannot use the "fast" gradient clipping
# algorithm.
reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches)
reshape_fn = lambda z: lr.maybe_add_microbatch_axis(z, num_microbatches)
microbatched_data = tf.nest.map_structure(reshape_fn, data)
clipped_grads = tf.vectorized_map(
self._compute_per_example_grads,
@ -305,7 +305,9 @@ def make_dp_model_class(cls):
output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss
# Log the true loss, including regularization losses.
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
self.compiled_loss(
y, y_pred, sample_weight=weights, regularization_losses=self.losses
)
# Forward the private gradients to the optimizer and return the results.
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

View file

@ -21,8 +21,11 @@ from tensorflow_privacy.privacy.keras_models import dp_keras_model
def get_data():
# Data is for hidden weights of [3, 1] and bias of 2.
# With mean squared loss, we expect loss = 15^2 = 225, gradients of
# weights = [90, 120], and gradient of bias = 30.
# Loss is (w.x + b - y)^2, model is initialized at (w, b) = (0, 0).
# y = 15
# Loss: y^2 = 15^2 = 225
# Gradient w.r.t. w = -2yx = [90, 120]
# Gradient w.r.t. b = -2y = 30
data = np.array([[3, 4]])
labels = np.matmul(data, [[3], [1]]) + 2
return data, labels
@ -41,8 +44,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros')
])
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss = tf.keras.losses.MeanSquaredError()
@ -58,101 +63,149 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
l2_norm_clip=(10.0, 40.0, 200.0),
fast_clipping=(True, False),
sequential=(True, False),
weighted=(False, True),
)
def testClippingNorm(self, l2_norm_clip, fast_clipping):
def testClippingNorm(self, l2_norm_clip, fast_clipping, sequential, weighted):
"""Tests that clipping norm works."""
train_data, train_labels = get_data()
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
layer = tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
)
if sequential:
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer],
)
else:
inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)
outputs = layer(inputs)
model = dp_keras_model.DPModel(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
inputs=inputs,
outputs=outputs,
)
learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer, loss=loss)
expected_loss = loss(train_labels, model(train_data))
results = model.fit(train_data, train_labels, epochs=1, batch_size=1)
model_weights = model.get_weights()
weights = None
data = tf.data.Dataset.from_tensors((train_data, train_labels))
expected_grad_w = np.array([90.0, 120.0])
expected_grad_b = np.array([30.0])
if weighted:
# Apply a weight to the (single) example.
weights = [0.18]
data = tf.data.Dataset.from_tensors((train_data, train_labels, weights))
expected_grad_w *= 0.18
expected_grad_b *= 0.18
unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
scale = min(1.0, l2_norm_clip / unclipped_gradient)
expected_weights = np.array([[90], [120]]) * scale * learning_rate
expected_bias = np.array([30]) * scale * learning_rate
unclipped_norm = np.linalg.norm(
np.concatenate([expected_grad_w, expected_grad_b])
)
scale = min(1.0, l2_norm_clip / unclipped_norm)
expected_weights = expected_grad_w * scale * learning_rate
expected_bias = expected_grad_b * scale * learning_rate
expected_loss = loss(train_labels, model(train_data), weights)
results = model.fit(data, epochs=1, batch_size=1)
weights, bias = model.get_weights()
# Check parameters are as expected, taking into account the learning rate.
self.assertAllClose(model_weights[0], expected_weights)
self.assertAllClose(model_weights[1], expected_bias)
self.assertAllClose(np.squeeze(weights), expected_weights)
self.assertAllClose(bias, expected_bias)
# Check the value of the loss.
actual_loss = results.history['loss'][0]
self.assertAllClose(expected_loss, actual_loss)
def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
num_microbatches):
def _compute_expected_gradients(
self,
data,
labels,
weights,
w0,
l2_norm_clip,
num_microbatches,
):
if weights is None:
weights = np.array([1], dtype=np.float32)
batch_size = data.shape[0]
if num_microbatches is None:
num_microbatches = batch_size
preds = np.matmul(data, np.expand_dims(w, axis=1))
grads = 2 * data * (preds - labels)
grads = np.reshape(grads,
[num_microbatches, batch_size // num_microbatches, -1])
preds = np.matmul(data, w0[:, np.newaxis])
grads = 2 * data * (preds - labels) * weights[:, np.newaxis]
grads = np.reshape(
grads, [num_microbatches, batch_size // num_microbatches, -1]
)
mb_grads = np.mean(grads, axis=1)
mb_grad_norms = np.linalg.norm(mb_grads, axis=1)
scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0)
mb_grads = mb_grads * scale[:, np.newaxis]
final_grads = np.mean(mb_grads, axis=0)
return final_grads
@parameterized.product(
num_microbatches=(None, 1, 2, 4),
fast_clipping=(False, True),
fast_clipping=(True, False),
sequential=(False, True),
weighted=(True, False),
)
def testMicrobatches(self, num_microbatches, fast_clipping):
def testMicrobatches(
self, num_microbatches, fast_clipping, sequential, weighted
):
l2_norm_clip = 1.0
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
w = np.zeros((2))
train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
if weighted:
train_weights = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
dataset = tf.data.Dataset.from_tensors(
(train_data, train_labels, train_weights)
)
else:
train_weights = None
dataset = tf.data.Dataset.from_tensors((train_data, train_labels))
learning_rate = 1.0
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
# Simple linear model returns w * x.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, use_bias=False, kernel_initializer='zeros'
),
],
layer = tf.keras.layers.Dense(1, use_bias=False, kernel_initializer='zeros')
registry = (
layer_registry.make_default_layer_registry() if fast_clipping else None
)
if sequential:
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=registry,
layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer],
)
else:
inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)
outputs = layer(inputs)
model = dp_keras_model.DPModel(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=registry,
inputs=inputs,
outputs=outputs,
)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
model.fit(dataset, epochs=1, batch_size=4, shuffle=False)
model_weights = np.squeeze(model.get_weights())
@ -163,7 +216,14 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
)
expected_grads = self._compute_expected_gradients(
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
train_data,
train_labels,
train_weights,
np.zeros(
2,
),
l2_norm_clip,
effective_num_microbatches,
)
expected_weights = np.squeeze(-learning_rate * expected_grads)
self.assertAllClose(model_weights, expected_weights)