Support weighted losses in DPModel.
PiperOrigin-RevId: 538011437
This commit is contained in:
parent
60d237be83
commit
18c43b351b
5 changed files with 297 additions and 217 deletions
|
@ -69,9 +69,10 @@ def get_registry_generator_fn(
|
|||
|
||||
def compute_gradient_norms(
|
||||
input_model: tf.keras.Model,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
x_batch: InputTensor,
|
||||
y_batch: tf.Tensor,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
weight_batch: Optional[tf.Tensor] = None,
|
||||
per_example_loss_fn: Optional[LossFn] = None,
|
||||
num_microbatches: Optional[lr.BatchSize] = None,
|
||||
trainable_vars: Optional[List[tf.Variable]] = None,
|
||||
|
@ -83,16 +84,19 @@ def compute_gradient_norms(
|
|||
|
||||
Args:
|
||||
input_model: The `tf.keras.Model` from which to obtain the layers. The
|
||||
loss of the model *must* be a scalar loss.
|
||||
loss of the model *must* be a scalar loss. When using microbatching, the
|
||||
loss reduction must be mean.
|
||||
layer_registry: A `LayerRegistry` instance containing functions that help
|
||||
compute gradient norms quickly. See
|
||||
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||
more details.
|
||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
layer_registry: A `LayerRegistry` instance containing functions that help
|
||||
compute gradient norms quickly. See
|
||||
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||
more details.
|
||||
weight_batch: Optional batch of weights, passed to the loss function.
|
||||
Weights apply to the loss prior to clipping.
|
||||
per_example_loss_fn: takes as input predictions, labels and weights, and
|
||||
outputs a vector of per-example losses. If None, derived from
|
||||
`input_model.loss` by disabling its reduction.
|
||||
|
@ -108,8 +112,11 @@ def compute_gradient_norms(
|
|||
variables are included.
|
||||
|
||||
Returns:
|
||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||
per-example loss function.
|
||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||
weighted example loss (when num_microbatches is None) or the norm of the
|
||||
gradient of the i-th microbatch loss (defined as a mean over the microbatch).
|
||||
Note that when the loss is weighted (`weight_batch` is not None), weights
|
||||
are applied prior to clipping.
|
||||
"""
|
||||
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
||||
registry_generator_fn = get_registry_generator_fn(
|
||||
|
@ -127,7 +134,7 @@ def compute_gradient_norms(
|
|||
loss_config = input_model.loss.get_config()
|
||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||
losses = per_example_loss_fn(y_batch, model_outputs)
|
||||
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
|
||||
if losses.shape is None:
|
||||
raise NotImplementedError(
|
||||
"The unreduced (or per-example) loss's shape cannot be `None`"
|
||||
|
@ -140,7 +147,7 @@ def compute_gradient_norms(
|
|||
)
|
||||
if num_microbatches is not None:
|
||||
losses = tf.reduce_mean(
|
||||
lr.add_microbatch_axis(losses, num_microbatches), axis=1
|
||||
lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
|
||||
)
|
||||
summed_loss = tf.reduce_sum(losses)
|
||||
# Unwrap the generator outputs so that the next loop avoids duplicating
|
||||
|
@ -165,6 +172,10 @@ def compute_gradient_norms(
|
|||
vars_list,
|
||||
unconnected_gradients=tf.UnconnectedGradients.ZERO,
|
||||
)
|
||||
if not grads_list:
|
||||
raise ValueError('The gradient list cannot be empty.')
|
||||
if len(grads_list) != len(sqr_norm_fns_list):
|
||||
raise ValueError('There must be as many norms as gradients.')
|
||||
sqr_norm_list = []
|
||||
for grads, f in zip(grads_list, sqr_norm_fns_list):
|
||||
sqr_norm_list.append(f(grads))
|
||||
|
@ -199,30 +210,26 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
|
|||
|
||||
def compute_clipped_gradients_and_outputs(
|
||||
input_model: tf.keras.Model,
|
||||
x_batch: InputTensor,
|
||||
y_batch: tf.Tensor,
|
||||
l2_norm_clip: float,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
x_batch: InputTensor,
|
||||
y_batch: tf.Tensor,
|
||||
weight_batch: Optional[tf.Tensor] = None,
|
||||
num_microbatches: Optional[lr.BatchSize] = None,
|
||||
clipping_loss: Optional[LossFn] = None,
|
||||
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
|
||||
"""Computes the per-example clipped loss gradient and other useful outputs.
|
||||
|
||||
Given a batch of observations `(x_batch, y_batch)`, the main steps of this
|
||||
function are: (i) compute the l2-norm of the gradients of the trainable
|
||||
variables of `input_model` for each example in the batch; (ii) use the norms
|
||||
computed in (i) to obtain "clip_weights" that are used to generate a weighted
|
||||
loss function whose gradient for each example has l2-norm at most
|
||||
`l2_norm_clip`; (iii) output the clipped gradients in (ii) and other useful
|
||||
outputs to the caller.
|
||||
Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
|
||||
steps of this function are:
|
||||
(i) compute the l2-norm of the gradients w.r.t. the trainable variables of
|
||||
`input_model`, for each weighted example loss in the batch;
|
||||
(ii) use the norms computed in (i) to obtain "clip_weights" that are used to
|
||||
reweight the loss function, such that each gradient of this reweighted loss
|
||||
has l2-norm at most `l2_norm_clip`.
|
||||
|
||||
Args:
|
||||
input_model: The `tf.keras.Model` from which to obtain the layers.
|
||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
l2_norm_clip: A `float` indicating the norm to which per-example gradients
|
||||
will be clipped. That is, all gradients of the per-example loss functions
|
||||
will have norm at most `l2_norm_clip`.
|
||||
|
@ -232,6 +239,15 @@ def compute_clipped_gradients_and_outputs(
|
|||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||
trainable weights (see `layer_registry_factories.py` for examples).
|
||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
weight_batch: Optional vector of weights, passed to the loss function. Must
|
||||
be of size [batch_size]. In case of microbatching, this will be reshaped
|
||||
to [num_microbatches, batch_size/num_microbatches] before passing it to
|
||||
the loss.
|
||||
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
|
||||
microbatches. If not None, indicates that the loss is grouped into
|
||||
num_microbatches (in this case, the batch dimension needs to be a multiple
|
||||
|
@ -243,11 +259,11 @@ def compute_clipped_gradients_and_outputs(
|
|||
the value of the clipped loss does not reflect the true loss.
|
||||
|
||||
Returns:
|
||||
A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
|
||||
clipped gradient of the loss function, the second is the result of
|
||||
applying `input_model` to `x_batch`, and the third is loss value of
|
||||
`input_model`, weighted by the loss weights generated by a specific
|
||||
`compute_clip_weights()` call.
|
||||
clipped_grad: list of the clipped gradients of the loss function (one per
|
||||
trainable variable in `input_model`).
|
||||
y_pred: the result of applying `input_model` to `x_batch`.
|
||||
clipping_loss_value: the loss value weighted in such a way that its gradient
|
||||
is `clipped_grad`.
|
||||
"""
|
||||
if input_model.loss.reduction == 'none':
|
||||
raise NotImplementedError(
|
||||
|
@ -258,13 +274,26 @@ def compute_clipped_gradients_and_outputs(
|
|||
clipping_loss = input_model.compiled_loss
|
||||
gradient_norms = compute_gradient_norms(
|
||||
input_model,
|
||||
layer_registry,
|
||||
x_batch,
|
||||
y_batch,
|
||||
layer_registry,
|
||||
weight_batch,
|
||||
num_microbatches=num_microbatches,
|
||||
trainable_vars=input_model.trainable_variables,
|
||||
)
|
||||
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
||||
clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
|
||||
if weight_batch is not None:
|
||||
# Let w be the `weight_batch`, c be the `clip_weights`, and l be the losses.
|
||||
# c is computed based on the gradient of w*l, so that if we scale w*l by c,
|
||||
# the result has bounded per-example gradients. So the loss to optimize is
|
||||
# c*w*l. Here we compute c*w before passing it to the loss.
|
||||
weight_batch = lr.maybe_add_microbatch_axis(weight_batch, num_microbatches)
|
||||
if num_microbatches is None:
|
||||
clip_weights = clip_weights * weight_batch # shape [num_microbatches]
|
||||
else:
|
||||
# Here, weight_batch is of shape [num_microbatches, microbatch_size],
|
||||
# we multiply by the clip_weights (which is of shape [num_microbatches])
|
||||
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
|
||||
with tf.GradientTape() as tape:
|
||||
# WARNING: When num_microbatches is not None, we need to be sure that
|
||||
# `compute_loss` always computes the mean over the microbatches
|
||||
|
@ -274,17 +303,9 @@ def compute_clipped_gradients_and_outputs(
|
|||
# is not defined in the contract so may not hold, especially for
|
||||
# custom losses.
|
||||
y_pred = input_model(x_batch, training=True)
|
||||
loss_y_batch = (
|
||||
y_batch
|
||||
if num_microbatches is None
|
||||
else lr.add_microbatch_axis(y_batch, num_microbatches)
|
||||
)
|
||||
loss_y_pred = (
|
||||
y_pred
|
||||
if num_microbatches is None
|
||||
else lr.add_microbatch_axis(y_pred, num_microbatches)
|
||||
)
|
||||
clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
|
||||
mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
|
||||
mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
|
||||
clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
|
||||
clipped_grads = tape.gradient(
|
||||
clipping_loss_value,
|
||||
input_model.trainable_variables,
|
||||
|
|
|
@ -12,12 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||
|
||||
|
@ -71,17 +69,23 @@ def double_dense_layer_computation(
|
|||
return [vars1, vars2], outputs, sqr_norm_fn
|
||||
|
||||
|
||||
def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
|
||||
x = tf.reshape(x, (tf.shape(x)[0], -1))
|
||||
y = tf.reshape(y, (tf.shape(y)[0], -1))
|
||||
def test_loss_fn(
|
||||
x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None
|
||||
) -> tf.Tensor:
|
||||
# Define a loss function which is unlikely to be coincidentally defined.
|
||||
return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1)
|
||||
if weights is None:
|
||||
weights = 1.0
|
||||
loss = 3.14 * tf.reduce_sum(
|
||||
tf.cast(weights, tf.float32) * tf.square(x - y), axis=1
|
||||
)
|
||||
return loss
|
||||
|
||||
|
||||
def compute_true_gradient_norms(
|
||||
input_model: tf.keras.Model,
|
||||
x_batch: tf.Tensor,
|
||||
y_batch: tf.Tensor,
|
||||
weight_batch: Optional[tf.Tensor],
|
||||
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||
num_microbatches: Optional[int],
|
||||
trainable_vars: Optional[tf.Variable] = None,
|
||||
|
@ -93,7 +97,7 @@ def compute_true_gradient_norms(
|
|||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||
with tf.GradientTape(persistent=True) as tape:
|
||||
y_pred = input_model(x_batch)
|
||||
loss = per_example_loss_fn(y_batch, y_pred)
|
||||
loss = per_example_loss_fn(y_batch, y_pred, weight_batch)
|
||||
if num_microbatches is not None:
|
||||
loss = tf.reduce_mean(
|
||||
tf.reshape(
|
||||
|
@ -123,7 +127,8 @@ def get_computed_and_true_norms(
|
|||
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||
num_microbatches: Optional[int],
|
||||
is_eager: bool,
|
||||
x_input: tf.Tensor,
|
||||
x_batch: tf.Tensor,
|
||||
weight_batch: Optional[tf.Tensor] = None,
|
||||
rng_seed: int = 777,
|
||||
registry: layer_registry.LayerRegistry = None,
|
||||
partial: bool = False,
|
||||
|
@ -146,10 +151,11 @@ def get_computed_and_true_norms(
|
|||
per_example_loss_fn: If not None, used as vectorized per example loss
|
||||
function.
|
||||
num_microbatches: The number of microbatches. None or an integer.
|
||||
is_eager: A `bool` that is `True` if the model should be run eagerly.
|
||||
x_input: `tf.Tensor` inputs to be tested.
|
||||
rng_seed: An `int` used to initialize model weights.
|
||||
registry: A `layer_registry.LayerRegistry` instance.
|
||||
is_eager: whether the model should be run eagerly.
|
||||
x_batch: inputs to be tested.
|
||||
weight_batch: optional weights passed to the loss.
|
||||
rng_seed: used as a seed for random initialization.
|
||||
registry: required for fast clipping.
|
||||
partial: Whether to compute the gradient norm with respect to a partial set
|
||||
of variables. If True, only consider the variables in the first layer.
|
||||
|
||||
|
@ -175,13 +181,14 @@ def get_computed_and_true_norms(
|
|||
trainable_vars = l.trainable_variables
|
||||
if trainable_vars:
|
||||
break
|
||||
y_pred = model(x_input)
|
||||
y_pred = model(x_batch)
|
||||
y_batch = tf.ones_like(y_pred)
|
||||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
computed_norms = clip_grads.compute_gradient_norms(
|
||||
model,
|
||||
x_input,
|
||||
y_batch,
|
||||
input_model=model,
|
||||
x_batch=x_batch,
|
||||
y_batch=y_batch,
|
||||
weight_batch=weight_batch,
|
||||
layer_registry=registry,
|
||||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
|
@ -190,8 +197,9 @@ def get_computed_and_true_norms(
|
|||
tf.keras.utils.set_random_seed(rng_seed)
|
||||
true_norms = compute_true_gradient_norms(
|
||||
model,
|
||||
x_input,
|
||||
x_batch,
|
||||
y_batch,
|
||||
weight_batch,
|
||||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
trainable_vars=trainable_vars,
|
||||
|
@ -309,24 +317,16 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
|
|||
# ==============================================================================
|
||||
# Factory functions.
|
||||
# ==============================================================================
|
||||
def get_nd_test_tensors(n: int):
|
||||
"""Returns a list of candidate tests for a given dimension n."""
|
||||
return [
|
||||
tf.zeros((n,), dtype=tf.float64),
|
||||
tf.convert_to_tensor(range(n), dtype_hint=tf.float64),
|
||||
]
|
||||
|
||||
|
||||
def get_nd_test_batches(n: int):
|
||||
"""Returns a list of candidate input batches of dimension n."""
|
||||
result = []
|
||||
tensors = get_nd_test_tensors(n)
|
||||
for batch_size in range(1, len(tensors) + 1, 1):
|
||||
combinations = list(
|
||||
itertools.combinations(get_nd_test_tensors(n), batch_size)
|
||||
)
|
||||
result = result + [tf.stack(ts, axis=0) for ts in combinations]
|
||||
return result
|
||||
"""Returns a list of input batches of dimension n."""
|
||||
# The first two batches have a single element, the last batch has 2 elements.
|
||||
x0 = tf.zeros([1, n], dtype=tf.float64)
|
||||
x1 = tf.constant([range(n)], dtype=tf.float64)
|
||||
x2 = tf.concat([x0, x1], axis=0)
|
||||
w0 = tf.constant([1], dtype=tf.float64)
|
||||
w1 = tf.constant([2], dtype=tf.float64)
|
||||
w2 = tf.constant([0.5, 0.5], dtype=tf.float64)
|
||||
return [x0, x1, x2], [w0, w1, w2]
|
||||
|
||||
|
||||
def get_dense_layer_generators():
|
||||
|
@ -366,11 +366,14 @@ class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
def test_clip_weights(self, input_dim, clip_value):
|
||||
tol = 1e-6
|
||||
for t in get_nd_test_tensors(input_dim):
|
||||
self.assertIsNone(clip_grads.compute_clip_weights(None, t))
|
||||
ts, _ = get_nd_test_batches(input_dim)
|
||||
for t in ts:
|
||||
weights = clip_grads.compute_clip_weights(clip_value, t)
|
||||
self.assertAllLessEqual(t * weights, clip_value + tol)
|
||||
|
||||
def test_clip_weights_none(self):
|
||||
self.assertIsNone(clip_grads.compute_clip_weights(None, tf.ones(3)))
|
||||
|
||||
|
||||
class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -383,6 +386,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=[None, 1, 2],
|
||||
is_eager=[True, False],
|
||||
partial=[True, False],
|
||||
weighted=[True, False],
|
||||
)
|
||||
def test_gradient_norms_on_various_models(
|
||||
self,
|
||||
|
@ -394,21 +398,16 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches,
|
||||
is_eager,
|
||||
partial,
|
||||
weighted,
|
||||
):
|
||||
model_generator = get_dense_model_generators()[model_name]
|
||||
layer_generator = get_dense_layer_generators()[layer_name]
|
||||
x_batches = get_nd_test_batches(input_dim)
|
||||
x_batches, weight_batches = get_nd_test_batches(input_dim)
|
||||
default_registry = layer_registry.make_default_layer_registry()
|
||||
for x_batch in x_batches:
|
||||
if (
|
||||
num_microbatches is not None
|
||||
and x_batch.shape[0] % num_microbatches != 0
|
||||
):
|
||||
for x_batch, weight_batch in zip(x_batches, weight_batches):
|
||||
batch_size = x_batch.shape[0]
|
||||
if num_microbatches is not None and batch_size % num_microbatches != 0:
|
||||
continue
|
||||
if model_name == 'tower1':
|
||||
x_input = [x_batch, x_batch]
|
||||
else:
|
||||
x_input = x_batch
|
||||
(computed_norms, true_norms) = get_computed_and_true_norms(
|
||||
model_generator,
|
||||
layer_generator,
|
||||
|
@ -417,10 +416,13 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
per_example_loss_fn,
|
||||
num_microbatches,
|
||||
is_eager,
|
||||
x_input,
|
||||
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
|
||||
weight_batch=weight_batch if weighted else None,
|
||||
registry=default_registry,
|
||||
partial=partial,
|
||||
)
|
||||
expected_size = num_microbatches or batch_size
|
||||
self.assertEqual(computed_norms.shape[0], expected_size)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
|
@ -471,16 +473,14 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
is_eager,
|
||||
partial,
|
||||
):
|
||||
batch_size = x_batch.shape[0]
|
||||
# The following are invalid test combinations, and are skipped.
|
||||
if (
|
||||
num_microbatches is not None
|
||||
and x_batch.shape[0] % num_microbatches != 0
|
||||
num_microbatches is not None and batch_size % num_microbatches != 0
|
||||
) or (
|
||||
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
|
||||
):
|
||||
return
|
||||
valid_test_input = (
|
||||
not isinstance(x_batch, tf.RaggedTensor)
|
||||
and model_name == 'weighted_bow1'
|
||||
) or (model_name != 'weighted_bow1')
|
||||
if valid_test_input:
|
||||
default_registry = layer_registry.make_default_layer_registry()
|
||||
model_generator = get_embedding_model_generators()[model_name]
|
||||
(computed_norms, true_norms) = get_computed_and_true_norms(
|
||||
|
@ -491,10 +491,11 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
is_eager=is_eager,
|
||||
x_input=x_batch,
|
||||
x_batch=x_batch,
|
||||
registry=default_registry,
|
||||
partial=partial,
|
||||
)
|
||||
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
|
@ -507,6 +508,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=[None, 2],
|
||||
is_eager=[True, False],
|
||||
partial=[True, False],
|
||||
weighted=[True, False],
|
||||
)
|
||||
def test_gradient_norms_on_various_models(
|
||||
self,
|
||||
|
@ -516,15 +518,14 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches,
|
||||
is_eager,
|
||||
partial,
|
||||
weighted,
|
||||
):
|
||||
registry = layer_registry.make_default_layer_registry()
|
||||
registry.insert(DoubleDense, double_dense_layer_computation)
|
||||
x_batches = get_nd_test_batches(input_dim)
|
||||
for x_batch in x_batches:
|
||||
if (
|
||||
num_microbatches is not None
|
||||
and x_batch.shape[0] % num_microbatches != 0
|
||||
):
|
||||
x_batches, weight_batches = get_nd_test_batches(input_dim)
|
||||
for x_batch, weight_batch in zip(x_batches, weight_batches):
|
||||
batch_size = x_batch.shape[0]
|
||||
if num_microbatches is not None and batch_size % num_microbatches != 0:
|
||||
continue
|
||||
(computed_norms, true_norms) = get_computed_and_true_norms(
|
||||
model_generator=make_two_layer_sequential_model,
|
||||
|
@ -534,10 +535,12 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
per_example_loss_fn=per_example_loss_fn,
|
||||
num_microbatches=num_microbatches,
|
||||
is_eager=is_eager,
|
||||
x_input=x_batch,
|
||||
x_batch=x_batch,
|
||||
weight_batch=weight_batch if weighted else None,
|
||||
registry=registry,
|
||||
partial=partial,
|
||||
)
|
||||
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
|
||||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
|
@ -574,16 +577,13 @@ class ClipGradsComputeClippedGradsAndOutputsTest(
|
|||
)
|
||||
# Stop early for efficiency.
|
||||
if reduction == 'none':
|
||||
self.assertRaises(
|
||||
NotImplementedError,
|
||||
# function tested
|
||||
clip_grads.compute_clipped_gradients_and_outputs,
|
||||
# function args
|
||||
with self.assertRaises(NotImplementedError):
|
||||
clip_grads.compute_clipped_gradients_and_outputs(
|
||||
self._model,
|
||||
x_batch,
|
||||
y_batch,
|
||||
l2_norm_clip,
|
||||
layer_registry.make_default_layer_registry(),
|
||||
x_batch,
|
||||
y_batch,
|
||||
)
|
||||
return
|
||||
# NOTE: losses from this point are scalar losses.
|
||||
|
@ -593,10 +593,10 @@ class ClipGradsComputeClippedGradsAndOutputsTest(
|
|||
true_grads = tape.gradient(loss_value, self._model.trainable_variables)
|
||||
clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
|
||||
self._model,
|
||||
x_batch,
|
||||
y_batch,
|
||||
l2_norm_clip,
|
||||
layer_registry.make_default_layer_registry(),
|
||||
x_batch,
|
||||
y_batch,
|
||||
)
|
||||
|
||||
# Computes the L2 norm manually.
|
||||
|
|
|
@ -105,26 +105,23 @@ class LayerRegistry:
|
|||
# ==============================================================================
|
||||
# Utilities
|
||||
# ==============================================================================
|
||||
def add_microbatch_axis(
|
||||
def maybe_add_microbatch_axis(
|
||||
x: tf.Tensor,
|
||||
num_microbatches: Optional[BatchSize],
|
||||
) -> tf.Tensor:
|
||||
"""Adds the microbatch axis.
|
||||
|
||||
Reshape the input tensor to replace the first (batch) dimension with the
|
||||
shape [num_microbatches, batch_size / num_microbatches]. The batch size
|
||||
must be a multiple of num_microbatches (unless it is None, meaning
|
||||
num_microbatches is the same as the batch size).
|
||||
|
||||
Args:
|
||||
x: the input tensor.
|
||||
num_microbatches: None or a numeric value or a scalar `tf.Tensor`.
|
||||
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
|
||||
the batch size.
|
||||
|
||||
Returns:
|
||||
The reshaped input tensor.
|
||||
The input tensor x, reshaped from [batch_size, ...] to
|
||||
[num_microbatches, batch_size / num_microbatches, ...].
|
||||
"""
|
||||
if num_microbatches is None:
|
||||
return tf.expand_dims(x, 1)
|
||||
return x
|
||||
with tf.control_dependencies(
|
||||
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
|
||||
):
|
||||
|
@ -193,7 +190,7 @@ def dense_layer_computation(
|
|||
|
||||
def _compute_gramian(x):
|
||||
if num_microbatches is not None:
|
||||
x_microbatched = add_microbatch_axis(x, num_microbatches)
|
||||
x_microbatched = maybe_add_microbatch_axis(x, num_microbatches)
|
||||
return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
|
||||
else:
|
||||
# Special handling for better efficiency
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||
|
@ -179,15 +178,19 @@ def make_dp_model_class(cls):
|
|||
tf.shape(stacked_grads)[0], summed_grads.dtype
|
||||
)
|
||||
|
||||
def _compute_per_example_grads(self, data):
|
||||
def _compute_per_example_grads(self, microbatched_data):
|
||||
if self._clipping_loss is None:
|
||||
self._make_clipping_loss()
|
||||
microbatched_x, microbatched_y = data
|
||||
microbatched_x, microbatched_y, microbatched_weights = (
|
||||
tf.keras.utils.unpack_x_y_sample_weight(microbatched_data)
|
||||
)
|
||||
with tf.GradientTape() as tape:
|
||||
microbatched_y_pred = self(microbatched_x, training=True)
|
||||
# NOTE: `self._clipping_loss` does not include any regularization terms.
|
||||
microbatched_loss = self._clipping_loss(
|
||||
microbatched_y, microbatched_y_pred
|
||||
microbatched_y,
|
||||
microbatched_y_pred,
|
||||
sample_weight=microbatched_weights,
|
||||
)
|
||||
grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
|
||||
clipped_grads = self._process_per_example_grads(grads_list)
|
||||
|
@ -232,12 +235,8 @@ def make_dp_model_class(cls):
|
|||
self._make_clipping_loss()
|
||||
output_metrics = {}
|
||||
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||
if weights is not None:
|
||||
raise NotImplementedError(
|
||||
'DPModel does not currently support weighted losses.'
|
||||
)
|
||||
batch_size = tf.shape(y)[0]
|
||||
eff_num_microbatches = self._num_microbatches or batch_size
|
||||
num_microbatches = self._num_microbatches or batch_size
|
||||
|
||||
# Branch based on gradient clipping algorithm.
|
||||
if self._enable_fast_peg_computation:
|
||||
|
@ -251,13 +250,14 @@ def make_dp_model_class(cls):
|
|||
# microbatches is done here.
|
||||
clipped_grads, y_pred, clipping_loss = (
|
||||
clip_grads.compute_clipped_gradients_and_outputs(
|
||||
self,
|
||||
x,
|
||||
y,
|
||||
self._l2_norm_clip,
|
||||
self._layer_registry,
|
||||
self._num_microbatches,
|
||||
self._clipping_loss,
|
||||
input_model=self,
|
||||
x_batch=x,
|
||||
y_batch=y,
|
||||
weight_batch=weights,
|
||||
l2_norm_clip=self._l2_norm_clip,
|
||||
layer_registry=self._layer_registry,
|
||||
num_microbatches=self._num_microbatches,
|
||||
clipping_loss=self._clipping_loss,
|
||||
)
|
||||
)
|
||||
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
|
||||
|
@ -265,7 +265,7 @@ def make_dp_model_class(cls):
|
|||
grads = gradient_clipping_utils.add_aggregate_noise(
|
||||
self,
|
||||
clipped_grads,
|
||||
eff_num_microbatches,
|
||||
num_microbatches,
|
||||
self._l2_norm_clip,
|
||||
self._noise_multiplier,
|
||||
)
|
||||
|
@ -276,7 +276,7 @@ def make_dp_model_class(cls):
|
|||
# Computes per-example clipped gradients directly. This is called
|
||||
# if at least one of the layers cannot use the "fast" gradient clipping
|
||||
# algorithm.
|
||||
reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches)
|
||||
reshape_fn = lambda z: lr.maybe_add_microbatch_axis(z, num_microbatches)
|
||||
microbatched_data = tf.nest.map_structure(reshape_fn, data)
|
||||
clipped_grads = tf.vectorized_map(
|
||||
self._compute_per_example_grads,
|
||||
|
@ -305,7 +305,9 @@ def make_dp_model_class(cls):
|
|||
output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss
|
||||
|
||||
# Log the true loss, including regularization losses.
|
||||
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
|
||||
self.compiled_loss(
|
||||
y, y_pred, sample_weight=weights, regularization_losses=self.losses
|
||||
)
|
||||
|
||||
# Forward the private gradients to the optimizer and return the results.
|
||||
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
|
||||
|
|
|
@ -21,8 +21,11 @@ from tensorflow_privacy.privacy.keras_models import dp_keras_model
|
|||
|
||||
def get_data():
|
||||
# Data is for hidden weights of [3, 1] and bias of 2.
|
||||
# With mean squared loss, we expect loss = 15^2 = 225, gradients of
|
||||
# weights = [90, 120], and gradient of bias = 30.
|
||||
# Loss is (w.x + b - y)^2, model is initialized at (w, b) = (0, 0).
|
||||
# y = 15
|
||||
# Loss: y^2 = 15^2 = 225
|
||||
# Gradient w.r.t. w = -2yx = [90, 120]
|
||||
# Gradient w.r.t. b = -2y = 30
|
||||
data = np.array([[3, 4]])
|
||||
labels = np.matmul(data, [[3], [1]]) + 2
|
||||
return data, labels
|
||||
|
@ -41,8 +44,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
|||
layers=[
|
||||
tf.keras.layers.InputLayer(input_shape=(2,)),
|
||||
tf.keras.layers.Dense(
|
||||
1, kernel_initializer='zeros', bias_initializer='zeros')
|
||||
])
|
||||
1, kernel_initializer='zeros', bias_initializer='zeros'
|
||||
),
|
||||
],
|
||||
)
|
||||
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
|
||||
loss = tf.keras.losses.MeanSquaredError()
|
||||
|
||||
|
@ -58,101 +63,149 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
|||
@parameterized.product(
|
||||
l2_norm_clip=(10.0, 40.0, 200.0),
|
||||
fast_clipping=(True, False),
|
||||
sequential=(True, False),
|
||||
weighted=(False, True),
|
||||
)
|
||||
def testClippingNorm(self, l2_norm_clip, fast_clipping):
|
||||
def testClippingNorm(self, l2_norm_clip, fast_clipping, sequential, weighted):
|
||||
"""Tests that clipping norm works."""
|
||||
train_data, train_labels = get_data()
|
||||
|
||||
# Simple linear model returns w * x + b.
|
||||
layer = tf.keras.layers.Dense(
|
||||
1, kernel_initializer='zeros', bias_initializer='zeros'
|
||||
)
|
||||
if sequential:
|
||||
model = dp_keras_model.DPSequential(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=0.0,
|
||||
layer_registry=layer_registry.make_default_layer_registry()
|
||||
if fast_clipping
|
||||
else None,
|
||||
layers=[
|
||||
tf.keras.layers.InputLayer(input_shape=(2,)),
|
||||
tf.keras.layers.Dense(
|
||||
1, kernel_initializer='zeros', bias_initializer='zeros'
|
||||
),
|
||||
],
|
||||
layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer],
|
||||
)
|
||||
else:
|
||||
inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)
|
||||
outputs = layer(inputs)
|
||||
model = dp_keras_model.DPModel(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=0.0,
|
||||
layer_registry=layer_registry.make_default_layer_registry()
|
||||
if fast_clipping
|
||||
else None,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
learning_rate = 0.01
|
||||
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
|
||||
loss = tf.keras.losses.MeanSquaredError()
|
||||
model.compile(optimizer=optimizer, loss=loss)
|
||||
expected_loss = loss(train_labels, model(train_data))
|
||||
results = model.fit(train_data, train_labels, epochs=1, batch_size=1)
|
||||
|
||||
model_weights = model.get_weights()
|
||||
weights = None
|
||||
data = tf.data.Dataset.from_tensors((train_data, train_labels))
|
||||
expected_grad_w = np.array([90.0, 120.0])
|
||||
expected_grad_b = np.array([30.0])
|
||||
if weighted:
|
||||
# Apply a weight to the (single) example.
|
||||
weights = [0.18]
|
||||
data = tf.data.Dataset.from_tensors((train_data, train_labels, weights))
|
||||
expected_grad_w *= 0.18
|
||||
expected_grad_b *= 0.18
|
||||
|
||||
unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
|
||||
scale = min(1.0, l2_norm_clip / unclipped_gradient)
|
||||
expected_weights = np.array([[90], [120]]) * scale * learning_rate
|
||||
expected_bias = np.array([30]) * scale * learning_rate
|
||||
unclipped_norm = np.linalg.norm(
|
||||
np.concatenate([expected_grad_w, expected_grad_b])
|
||||
)
|
||||
scale = min(1.0, l2_norm_clip / unclipped_norm)
|
||||
expected_weights = expected_grad_w * scale * learning_rate
|
||||
expected_bias = expected_grad_b * scale * learning_rate
|
||||
expected_loss = loss(train_labels, model(train_data), weights)
|
||||
results = model.fit(data, epochs=1, batch_size=1)
|
||||
weights, bias = model.get_weights()
|
||||
|
||||
# Check parameters are as expected, taking into account the learning rate.
|
||||
self.assertAllClose(model_weights[0], expected_weights)
|
||||
self.assertAllClose(model_weights[1], expected_bias)
|
||||
self.assertAllClose(np.squeeze(weights), expected_weights)
|
||||
self.assertAllClose(bias, expected_bias)
|
||||
|
||||
# Check the value of the loss.
|
||||
actual_loss = results.history['loss'][0]
|
||||
self.assertAllClose(expected_loss, actual_loss)
|
||||
|
||||
def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
|
||||
num_microbatches):
|
||||
def _compute_expected_gradients(
|
||||
self,
|
||||
data,
|
||||
labels,
|
||||
weights,
|
||||
w0,
|
||||
l2_norm_clip,
|
||||
num_microbatches,
|
||||
):
|
||||
if weights is None:
|
||||
weights = np.array([1], dtype=np.float32)
|
||||
batch_size = data.shape[0]
|
||||
if num_microbatches is None:
|
||||
num_microbatches = batch_size
|
||||
|
||||
preds = np.matmul(data, np.expand_dims(w, axis=1))
|
||||
|
||||
grads = 2 * data * (preds - labels)
|
||||
|
||||
grads = np.reshape(grads,
|
||||
[num_microbatches, batch_size // num_microbatches, -1])
|
||||
|
||||
preds = np.matmul(data, w0[:, np.newaxis])
|
||||
grads = 2 * data * (preds - labels) * weights[:, np.newaxis]
|
||||
grads = np.reshape(
|
||||
grads, [num_microbatches, batch_size // num_microbatches, -1]
|
||||
)
|
||||
mb_grads = np.mean(grads, axis=1)
|
||||
mb_grad_norms = np.linalg.norm(mb_grads, axis=1)
|
||||
|
||||
scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0)
|
||||
|
||||
mb_grads = mb_grads * scale[:, np.newaxis]
|
||||
|
||||
final_grads = np.mean(mb_grads, axis=0)
|
||||
return final_grads
|
||||
|
||||
@parameterized.product(
|
||||
num_microbatches=(None, 1, 2, 4),
|
||||
fast_clipping=(False, True),
|
||||
fast_clipping=(True, False),
|
||||
sequential=(False, True),
|
||||
weighted=(True, False),
|
||||
)
|
||||
def testMicrobatches(self, num_microbatches, fast_clipping):
|
||||
def testMicrobatches(
|
||||
self, num_microbatches, fast_clipping, sequential, weighted
|
||||
):
|
||||
l2_norm_clip = 1.0
|
||||
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
|
||||
w = np.zeros((2))
|
||||
train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
|
||||
if weighted:
|
||||
train_weights = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
|
||||
dataset = tf.data.Dataset.from_tensors(
|
||||
(train_data, train_labels, train_weights)
|
||||
)
|
||||
else:
|
||||
train_weights = None
|
||||
dataset = tf.data.Dataset.from_tensors((train_data, train_labels))
|
||||
learning_rate = 1.0
|
||||
|
||||
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
|
||||
loss = tf.keras.losses.MeanSquaredError()
|
||||
|
||||
# Simple linear model returns w * x.
|
||||
layer = tf.keras.layers.Dense(1, use_bias=False, kernel_initializer='zeros')
|
||||
registry = (
|
||||
layer_registry.make_default_layer_registry() if fast_clipping else None
|
||||
)
|
||||
if sequential:
|
||||
model = dp_keras_model.DPSequential(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=num_microbatches,
|
||||
layer_registry=layer_registry.make_default_layer_registry()
|
||||
if fast_clipping
|
||||
else None,
|
||||
layers=[
|
||||
tf.keras.layers.InputLayer(input_shape=(2,)),
|
||||
tf.keras.layers.Dense(
|
||||
1, use_bias=False, kernel_initializer='zeros'
|
||||
),
|
||||
],
|
||||
layer_registry=registry,
|
||||
layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer],
|
||||
)
|
||||
else:
|
||||
inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)
|
||||
outputs = layer(inputs)
|
||||
model = dp_keras_model.DPModel(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=num_microbatches,
|
||||
layer_registry=registry,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
model.compile(optimizer=optimizer, loss=loss)
|
||||
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
|
||||
model.fit(dataset, epochs=1, batch_size=4, shuffle=False)
|
||||
|
||||
model_weights = np.squeeze(model.get_weights())
|
||||
|
||||
|
@ -163,7 +216,14 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
|
||||
expected_grads = self._compute_expected_gradients(
|
||||
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
|
||||
train_data,
|
||||
train_labels,
|
||||
train_weights,
|
||||
np.zeros(
|
||||
2,
|
||||
),
|
||||
l2_norm_clip,
|
||||
effective_num_microbatches,
|
||||
)
|
||||
expected_weights = np.squeeze(-learning_rate * expected_grads)
|
||||
self.assertAllClose(model_weights, expected_weights)
|
||||
|
|
Loading…
Reference in a new issue