forked from 626_privacy/tensorflow_privacy
Add additional tests and checks on the passed loss function.
PiperOrigin-RevId: 532225904
This commit is contained in:
parent
8fdac5f833
commit
0f5acf868e
2 changed files with 94 additions and 1 deletions
|
@ -128,6 +128,16 @@ def compute_gradient_norms(
|
|||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||
losses = per_example_loss_fn(y_batch, model_outputs)
|
||||
if losses.shape is None:
|
||||
raise NotImplementedError(
|
||||
"The unreduced (or per-example) loss's shape cannot be `None`"
|
||||
)
|
||||
if len(losses.shape) != 1:
|
||||
raise NotImplementedError(
|
||||
'The unreduced (or per-example) loss needs to have a shape of length '
|
||||
'one, but received an unreduced loss of shape length %s'
|
||||
% len(losses.shape)
|
||||
)
|
||||
if num_microbatches is not None:
|
||||
losses = tf.reduce_mean(
|
||||
lr.add_microbatch_axis(losses, num_microbatches), axis=1
|
||||
|
@ -239,6 +249,11 @@ def compute_clipped_gradients_and_outputs(
|
|||
`input_model`, weighted by the loss weights generated by a specific
|
||||
`compute_clip_weights()` call.
|
||||
"""
|
||||
if input_model.loss.reduction == 'none':
|
||||
raise NotImplementedError(
|
||||
'Fast gradient clipping does not support '
|
||||
'models with unreduced loss functions.'
|
||||
)
|
||||
if clipping_loss is None:
|
||||
clipping_loss = input_model.compiled_loss
|
||||
gradient_norms = compute_gradient_norms(
|
||||
|
|
|
@ -426,7 +426,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
# TODO(wkong): Test sparse input tensors when the GitHub CI environment
|
||||
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
|
||||
# supports them for embeddings.
|
||||
@parameterized.product(
|
||||
x_batch=[
|
||||
|
@ -541,5 +541,83 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
class ClipGradsComputeClippedGradsAndOutputsTest(
|
||||
tf.test.TestCase, parameterized.TestCase
|
||||
):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
dense_generator = lambda a, b: tf.keras.layers.Dense(b)
|
||||
self._input_dim = 2
|
||||
self._output_dim = 3
|
||||
self._model = make_two_layer_sequential_model(
|
||||
dense_generator, self._input_dim, self._output_dim
|
||||
)
|
||||
|
||||
@parameterized.product(
|
||||
batch_size=[1, 2, 10],
|
||||
l2_norm_clip=[0.1, 1.0, 10],
|
||||
is_eager=[True, False],
|
||||
reduction=['auto', 'sum', 'sum_over_batch_size', 'none'],
|
||||
)
|
||||
def test_clipped_gradients_on_different_losses(
|
||||
self, batch_size, l2_norm_clip, is_eager, reduction
|
||||
):
|
||||
loss_fn = tf.keras.losses.MeanSquaredError(reduction=reduction)
|
||||
self._model.compile(loss=loss_fn, run_eagerly=is_eager)
|
||||
x_batch = tf.reshape(
|
||||
tf.range(batch_size * self._input_dim, dtype=tf.float32),
|
||||
[batch_size, -1],
|
||||
)
|
||||
y_batch = tf.reshape(
|
||||
1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
|
||||
)
|
||||
# Stop early for efficiency.
|
||||
if reduction == 'none':
|
||||
self.assertRaises(
|
||||
NotImplementedError,
|
||||
# function tested
|
||||
clip_grads.compute_clipped_gradients_and_outputs,
|
||||
# function args
|
||||
self._model,
|
||||
x_batch,
|
||||
y_batch,
|
||||
l2_norm_clip,
|
||||
layer_registry.make_default_layer_registry(),
|
||||
)
|
||||
return
|
||||
# NOTE: losses from this point are scalar losses.
|
||||
with tf.GradientTape() as tape:
|
||||
y_pred = self._model(x_batch)
|
||||
loss_value = loss_fn(y_pred, y_batch)
|
||||
true_grads = tape.gradient(loss_value, self._model.trainable_variables)
|
||||
clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
|
||||
self._model,
|
||||
x_batch,
|
||||
y_batch,
|
||||
l2_norm_clip,
|
||||
layer_registry.make_default_layer_registry(),
|
||||
)
|
||||
|
||||
# Computes the L2 norm manually.
|
||||
def compute_l2_norm(t):
|
||||
sqr_sum_fn = lambda x: tf.reduce_sum(tf.square(x))
|
||||
return tf.sqrt(tf.add_n(tf.nest.map_structure(sqr_sum_fn, t)))
|
||||
|
||||
true_norm = compute_l2_norm(true_grads)
|
||||
computed_norm = compute_l2_norm(clipped_grads)
|
||||
norm_bound = (
|
||||
l2_norm_clip * batch_size if reduction == 'sum' else l2_norm_clip
|
||||
)
|
||||
if true_norm >= norm_bound:
|
||||
# All of the per-example gradient norms should be less than the L2 norm
|
||||
# clip value. Hence, by the triangle inequality, the gradient norm of the
|
||||
# summed loss (averaged loss) should be less than the clip value times
|
||||
# the batch size (just the clip value).
|
||||
self.assertLessEqual(computed_norm, norm_bound)
|
||||
else:
|
||||
self.assertAlmostEqual(computed_norm, true_norm)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue