diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
index 70653e2..c11a93b 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -128,6 +128,16 @@ def compute_gradient_norms(
     loss_config['reduction'] = tf.keras.losses.Reduction.NONE
     per_example_loss_fn = input_model.loss.from_config(loss_config)
     losses = per_example_loss_fn(y_batch, model_outputs)
+    if losses.shape is None:
+      raise NotImplementedError(
+          "The unreduced (or per-example) loss's shape cannot be `None`"
+      )
+    if len(losses.shape) != 1:
+      raise NotImplementedError(
+          'The unreduced (or per-example) loss needs to have a shape of length '
+          'one, but received an unreduced loss of shape length %s'
+          % len(losses.shape)
+      )
     if num_microbatches is not None:
       losses = tf.reduce_mean(
           lr.add_microbatch_axis(losses, num_microbatches), axis=1
@@ -239,6 +249,11 @@ def compute_clipped_gradients_and_outputs(
       `input_model`, weighted by the loss weights generated by a specific
       `compute_clip_weights()` call.
   """
+  if input_model.loss.reduction == 'none':
+    raise NotImplementedError(
+        'Fast gradient clipping does not support '
+        'models with unreduced loss functions.'
+    )
   if clipping_loss is None:
     clipping_loss = input_model.compiled_loss
   gradient_norms = compute_gradient_norms(
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
index a0986fc..8d5ffcc 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
@@ -426,7 +426,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
 
 class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
 
-  # TODO(wkong): Test sparse input tensors when the GitHub CI environment
+  # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
   @parameterized.product(
       x_batch=[
@@ -541,5 +541,83 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
 
 
+class ClipGradsComputeClippedGradsAndOutputsTest(
+    tf.test.TestCase, parameterized.TestCase
+):
+
+  def setUp(self):
+    super().setUp()
+    dense_generator = lambda a, b: tf.keras.layers.Dense(b)
+    self._input_dim = 2
+    self._output_dim = 3
+    self._model = make_two_layer_sequential_model(
+        dense_generator, self._input_dim, self._output_dim
+    )
+
+  @parameterized.product(
+      batch_size=[1, 2, 10],
+      l2_norm_clip=[0.1, 1.0, 10],
+      is_eager=[True, False],
+      reduction=['auto', 'sum', 'sum_over_batch_size', 'none'],
+  )
+  def test_clipped_gradients_on_different_losses(
+      self, batch_size, l2_norm_clip, is_eager, reduction
+  ):
+    loss_fn = tf.keras.losses.MeanSquaredError(reduction=reduction)
+    self._model.compile(loss=loss_fn, run_eagerly=is_eager)
+    x_batch = tf.reshape(
+        tf.range(batch_size * self._input_dim, dtype=tf.float32),
+        [batch_size, -1],
+    )
+    y_batch = tf.reshape(
+        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
+    )
+    # Stop early for efficiency.
+    if reduction == 'none':
+      self.assertRaises(
+          NotImplementedError,
+          # function tested
+          clip_grads.compute_clipped_gradients_and_outputs,
+          # function args
+          self._model,
+          x_batch,
+          y_batch,
+          l2_norm_clip,
+          layer_registry.make_default_layer_registry(),
+      )
+      return
+    # NOTE: losses from this point are scalar losses.
+    with tf.GradientTape() as tape:
+      y_pred = self._model(x_batch)
+      loss_value = loss_fn(y_pred, y_batch)
+    true_grads = tape.gradient(loss_value, self._model.trainable_variables)
+    clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
+        self._model,
+        x_batch,
+        y_batch,
+        l2_norm_clip,
+        layer_registry.make_default_layer_registry(),
+    )
+
+    # Computes the L2 norm manually.
+    def compute_l2_norm(t):
+      sqr_sum_fn = lambda x: tf.reduce_sum(tf.square(x))
+      return tf.sqrt(tf.add_n(tf.nest.map_structure(sqr_sum_fn, t)))
+
+    true_norm = compute_l2_norm(true_grads)
+    computed_norm = compute_l2_norm(clipped_grads)
+    norm_bound = (
+        l2_norm_clip * batch_size if reduction == 'sum' else l2_norm_clip
+    )
+    if true_norm >= norm_bound:
+      # All of the per-example gradient norms should be less than the L2 norm
+      # clip value. Hence, by the triangle inequality, the gradient norm of the
+      # summed loss (averaged loss) should be less than the clip value times
+      # the batch size (just the clip value).
+      self.assertLessEqual(computed_norm, norm_bound)
+    else:
+      self.assertAlmostEqual(computed_norm, true_norm)
+
+
 if __name__ == '__main__':
   tf.test.main()
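
For reference, a minimal sketch of how the new guard in compute_clipped_gradients_and_outputs surfaces to callers. The toy model, batch values, and clip value below are illustrative only, and the imports assume both modules live in the tensorflow_privacy.privacy.fast_gradient_clipping package touched by this change:

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Compile a toy model with an unreduced (per-example) loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
model.compile(
    loss=tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE
    )
)
x_batch = tf.ones([2, 2])
y_batch = tf.ones([2, 3])

# With this change, an unreduced loss is rejected up front instead of
# propagating per-example losses of an unexpected shape downstream.
try:
  clip_grads.compute_clipped_gradients_and_outputs(
      model,
      x_batch,
      y_batch,
      1.0,  # l2_norm_clip
      layer_registry.make_default_layer_registry(),
  )
except NotImplementedError as e:
  print(e)  # 'Fast gradient clipping does not support models with ...'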