diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
index 0c37d7f..70653e2 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -21,13 +21,14 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """
 
-from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
 
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 
 InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
+LossFn = Callable[..., tf.Tensor]
 
 
 def get_registry_generator_fn(
@@ -71,7 +72,7 @@ def compute_gradient_norms(
     x_batch: InputTensor,
     y_batch: tf.Tensor,
     layer_registry: lr.LayerRegistry,
-    per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
+    per_example_loss_fn: Optional[LossFn] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     trainable_vars: Optional[List[tf.Variable]] = None,
 ):
@@ -92,9 +93,9 @@ def compute_gradient_norms(
       compute gradient norms quickly. See
       `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
       more details.
-    per_example_loss_fn: If not None, used as the function to compute the
-      vectorized per example loss. Otherwise, we derive it from `input_model`'s
-      loss function.
+    per_example_loss_fn: takes as input predictions, labels and weights, and
+      outputs a vector of per-example losses. If None, derived from
+      `input_model.loss` by disabling its reduction.
     num_microbatches: An optional number or scalar `tf.Tensor` for the number
       of microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -193,7 +194,8 @@ def compute_clipped_gradients_and_outputs(
     l2_norm_clip: float,
     layer_registry: lr.LayerRegistry,
     num_microbatches: Optional[lr.BatchSize] = None,
-) -> Tuple[List[tf.Tensor], tf.Tensor, float]:
+    clipping_loss: Optional[LossFn] = None,
+) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
   """Computes the per-example clipped loss gradient and other useful outputs.
 
   Given a batch of observations `(x_batch, y_batch)`, the main steps of this
@@ -224,14 +226,21 @@ def compute_clipped_gradients_and_outputs(
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
      of num_microbatches).
+    clipping_loss: If provided, used for the clipping computation. Defaults to
+      `input_model.compiled_loss`. Specifying a `clipping_loss` can be useful to
+      avoid calling `input_model.compiled_loss`, as this will append the value
+      of the clipped loss to the reported metrics, and this can be misleading as
+      the value of the clipped loss does not reflect the true loss.
 
   Returns:
-    A `tuple` `(grad, y_pred, weighted_loss_value)`. The first element is the
+    A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
     clipped gradient of the loss function, the second is the result of
     applying `input_model` to `x_batch`, and the third is loss value of
     `input_model`, weighted by the loss weights generated by a specific
     `compute_clip_weights()` call.
""" + if clipping_loss is None: + clipping_loss = input_model.compiled_loss gradient_norms = compute_gradient_norms( input_model, x_batch, @@ -260,19 +269,10 @@ def compute_clipped_gradients_and_outputs( if num_microbatches is None else lr.add_microbatch_axis(y_pred, num_microbatches) ) - # NOTE: We do not log the loss values here. The caller should invoke - # `input_model.compute_loss()` to log loss values. Specifically, - # calling `input_model.compute_loss()` performs the following steps: - # - # (i) sums `input_model.loss` with the regularization losses given in - # `input_model.losses` to obtain the total loss - # (ii) evaluates the total loss with sample weights (if given) - weighted_loss_value = input_model.loss( - loss_y_batch, loss_y_pred, loss_weights - ) + clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights) clipped_grads = tape.gradient( - weighted_loss_value, + clipping_loss_value, input_model.trainable_variables, unconnected_gradients=tf.UnconnectedGradients.ZERO, ) - return clipped_grads, y_pred, weighted_loss_value + return clipped_grads, y_pred, clipping_loss_value diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index e7ec85d..176d7be 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -15,10 +15,13 @@ from absl import logging import tensorflow as tf + from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr +_PRIVATIZED_LOSS_NAME = 'privatized_loss' + def make_dp_model_class(cls): """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it.""" @@ -122,6 +125,7 @@ def make_dp_model_class(cls): self._l2_norm_clip = l2_norm_clip self._noise_multiplier = noise_multiplier self._layer_registry = layer_registry + self._clipping_loss = None # Given that `num_microbatches` was added as an argument after the fact, # this check helps detect unintended calls to the earlier API. @@ -176,15 +180,34 @@ def make_dp_model_class(cls): ) def _compute_per_example_grads(self, data): + if self._clipping_loss is None: + self._make_clipping_loss() microbatched_x, microbatched_y = data with tf.GradientTape() as tape: microbatched_y_pred = self(microbatched_x, training=True) - # NOTE: Calling `self.loss()` neither logs the total loss nor does it - # include any regularization terms. - microbatched_loss = self.loss(microbatched_y, microbatched_y_pred) + # NOTE: `self._clipping_loss` does not include any regularization terms. + microbatched_loss = self._clipping_loss( + microbatched_y, microbatched_y_pred + ) grads_list = tape.gradient(microbatched_loss, self.trainable_variables) clipped_grads = self._process_per_example_grads(grads_list) - return microbatched_loss, clipped_grads + return clipped_grads + + def _make_clipping_loss(self): + """Creates a LossesContainer to be used for clipping. + + To compute the privatized loss, we wrap the model's compiled_loss inside a + new LossesContainer. This lets us avoid calling model.compiled_loss, which + appends the loss value to the returned metrics (we want to avoid this as + the privatized loss does not reflect the true loss and can be misleading). 
+ """ + losses_container_cls = self.compiled_loss.__class__ + self._clipping_loss = losses_container_cls( + self.compiled_loss._user_losses, # pylint:disable=protected-access + loss_weights=self.compiled_loss._user_loss_weights, # pylint:disable=protected-access + output_names=self.output_names, + total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME), + ) def train_step(self, data): """DP-SGD version of base class method. @@ -205,11 +228,16 @@ def make_dp_model_class(cls): Returns: See the base class. """ + if self._clipping_loss is None: + self._make_clipping_loss() output_metrics = {} - x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data) + x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data) + if weights is not None: + raise NotImplementedError( + 'DPModel does not currently support weighted losses.' + ) batch_size = tf.shape(y)[0] eff_num_microbatches = self._num_microbatches or batch_size - privatized_loss_name = 'privatized_loss' # Branch based on gradient clipping algorithm. if self._enable_fast_peg_computation: @@ -221,7 +249,7 @@ def make_dp_model_class(cls): # trick, and uses these norms to clip the per-example gradients. # NOTE: Reshaping of the input according to the effective number of # microbatches is done here. - clipped_grads, y_pred, weighted_loss = ( + clipped_grads, y_pred, clipping_loss = ( clip_grads.compute_clipped_gradients_and_outputs( self, x, @@ -229,8 +257,10 @@ def make_dp_model_class(cls): self._l2_norm_clip, self._layer_registry, self._num_microbatches, + self._clipping_loss, ) ) + output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss if self._noise_multiplier > 0: grads = gradient_clipping_utils.add_aggregate_noise( self, @@ -241,7 +271,6 @@ def make_dp_model_class(cls): ) else: grads = clipped_grads - output_metrics[privatized_loss_name] = weighted_loss else: logging.info('Computing gradients using original clipping algorithm.') # Computes per-example clipped gradients directly. This is called @@ -249,7 +278,7 @@ def make_dp_model_class(cls): # algorithm. reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches) microbatched_data = tf.nest.map_structure(reshape_fn, data) - microbatched_losses, clipped_grads = tf.vectorized_map( + clipped_grads = tf.vectorized_map( self._compute_per_example_grads, microbatched_data, ) @@ -257,11 +286,6 @@ def make_dp_model_class(cls): grads = tf.nest.map_structure( self._reduce_per_example_grads, clipped_grads ) - if self.loss.reduction == tf.keras.losses.Reduction.SUM: - microbatched_loss = tf.reduce_sum(microbatched_losses) - else: - microbatched_loss = tf.reduce_mean(microbatched_losses) - output_metrics[privatized_loss_name] = microbatched_loss # Add the values and gradients contributed by regularization losses. if self.losses: @@ -277,9 +301,10 @@ def make_dp_model_class(cls): unconnected_gradients=tf.UnconnectedGradients.ZERO, ) grads = [a + b for (a, b) in zip(grads, regularization_grads)] - output_metrics[privatized_loss_name] += summed_regularization_loss + if self._enable_fast_peg_computation: + output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss - # Log the true loss. + # Log the true loss, including regularization losses. self.compiled_loss(y, y_pred, regularization_losses=self.losses) # Forward the private gradients to the optimizer and return the results. 
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py
index 1e4ac8a..d4dc724 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py
@@ -28,16 +28,6 @@ def get_data():
   return data, labels
 
 
-def get_layer_registries():
-  # Outputs a list of testable layer registries.
-  # The empty registry {} tests the behavior of the standard approach,
-  # while the other one tests the fast gradient clipping algorithm.
-  return [
-      layer_registry.LayerRegistry(),
-      layer_registry.make_default_layer_registry(),
-  ]
-
-
 class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
 
   def testBaseline(self):
@@ -65,44 +55,49 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(model_weights[0], [[0.90], [1.20]])
     self.assertAllClose(model_weights[1], [0.30])
 
-  @parameterized.named_parameters(
-      ('l2_norm_clip 10.0', 10.0),
-      ('l2_norm_clip 40.0', 40.0),
-      ('l2_norm_clip 200.0', 200.0),
+  @parameterized.product(
+      l2_norm_clip=(10.0, 40.0, 200.0),
+      fast_clipping=(True, False),
   )
-  def testClippingNorm(self, l2_norm_clip):
+  def testClippingNorm(self, l2_norm_clip, fast_clipping):
     """Tests that clipping norm works."""
     train_data, train_labels = get_data()
 
-    for test_reg in get_layer_registries():
-      # Simple linear model returns w * x + b.
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=l2_norm_clip,
-          noise_multiplier=0.0,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(2,)),
-              tf.keras.layers.Dense(
-                  1, kernel_initializer='zeros', bias_initializer='zeros'
-              ),
-          ],
-      )
-      learning_rate = 0.01
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss = tf.keras.losses.MeanSquaredError()
-      model.compile(optimizer=optimizer, loss=loss)
-      model.fit(train_data, train_labels, epochs=1, batch_size=1)
+    # Simple linear model returns w * x + b.
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=0.0,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                1, kernel_initializer='zeros', bias_initializer='zeros'
+            ),
+        ],
+    )
+    learning_rate = 0.01
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss = tf.keras.losses.MeanSquaredError()
+    model.compile(optimizer=optimizer, loss=loss)
+    expected_loss = loss(train_labels, model(train_data))
+    results = model.fit(train_data, train_labels, epochs=1, batch_size=1)
 
-      model_weights = model.get_weights()
+    model_weights = model.get_weights()
 
-      unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
-      scale = min(1.0, l2_norm_clip / unclipped_gradient)
-      expected_weights = np.array([[90], [120]]) * scale * learning_rate
-      expected_bias = np.array([30]) * scale * learning_rate
+    unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
+    scale = min(1.0, l2_norm_clip / unclipped_gradient)
+    expected_weights = np.array([[90], [120]]) * scale * learning_rate
+    expected_bias = np.array([30]) * scale * learning_rate
 
-      # Check parameters are as expected, taking into account the learning rate.
-      self.assertAllClose(model_weights[0], expected_weights)
-      self.assertAllClose(model_weights[1], expected_bias)
+    # Check parameters are as expected, taking into account the learning rate.
+    self.assertAllClose(model_weights[0], expected_weights)
+    self.assertAllClose(model_weights[1], expected_bias)
+
+    # Check the value of the loss.
+    actual_loss = results.history['loss'][0]
+    self.assertAllClose(expected_loss, actual_loss)
 
   def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
                                   num_microbatches):
@@ -127,64 +122,61 @@
     final_grads = np.mean(mb_grads, axis=0)
     return final_grads
 
-  @parameterized.named_parameters(
-      ('mb_test 0', 1.0, None),
-      ('mb_test 1', 1.0, 1),
-      ('mb_test 2', 1.0, 2),
-      ('mb_test 4', 1.0, 4),
+  @parameterized.product(
+      num_microbatches=(None, 1, 2, 4),
+      fast_clipping=(False, True),
   )
-  def testMicrobatches(self, l2_norm_clip, num_microbatches):
+  def testMicrobatches(self, num_microbatches, fast_clipping):
+    l2_norm_clip = 1.0
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
     w = np.zeros((2))
     train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
     learning_rate = 1.0
 
-    for test_reg in get_layer_registries():
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss = tf.keras.losses.MeanSquaredError()
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss = tf.keras.losses.MeanSquaredError()
 
-      # Simple linear model returns w * x.
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=l2_norm_clip,
-          noise_multiplier=0.0,
-          num_microbatches=num_microbatches,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(2,)),
-              tf.keras.layers.Dense(
-                  1, use_bias=False, kernel_initializer='zeros'
-              ),
-          ],
-      )
-      model.compile(optimizer=optimizer, loss=loss)
-      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+    # Simple linear model returns w * x.
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=0.0,
+        num_microbatches=num_microbatches,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                1, use_bias=False, kernel_initializer='zeros'
+            ),
+        ],
+    )
+    model.compile(optimizer=optimizer, loss=loss)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
 
-      model_weights = np.squeeze(model.get_weights())
+    model_weights = np.squeeze(model.get_weights())
 
-      effective_num_microbatches = (
-          train_data.shape[0]
-          if model._num_microbatches is None
-          else num_microbatches
-      )
+    effective_num_microbatches = (
+        train_data.shape[0]
+        if model._num_microbatches is None
+        else num_microbatches
+    )
 
-      expected_grads = self._compute_expected_gradients(
-          train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
-      )
-      expected_weights = np.squeeze(-learning_rate * expected_grads)
-      self.assertAllClose(model_weights, expected_weights)
+    expected_grads = self._compute_expected_gradients(
+        train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
+    )
+    expected_weights = np.squeeze(-learning_rate * expected_grads)
+    self.assertAllClose(model_weights, expected_weights)
 
-  @parameterized.named_parameters(
-      ('noise_multiplier 3 2 None', 3.0, 2.0, None),
-      ('noise_multiplier 5 4 None', 5.0, 4.0, None),
-      ('noise_multiplier 3 2 1', 3.0, 2.0, 1),
-      ('noise_multiplier 5 4 1', 5.0, 4.0, 1),
-      ('noise_multiplier 3 2 2', 3.0, 2.0, 2),
-      ('noise_multiplier 5 4 2', 5.0, 4.0, 2),
-      ('noise_multiplier 3 2 4', 3.0, 2.0, 4),
-      ('noise_multiplier 5 4 4', 5.0, 4.0, 4),
+  @parameterized.product(
+      l2_norm_clip=(3.0, 5.0),
+      noise_multiplier=(2.0, 4.0),
+      num_microbatches=(None, 1, 2, 4),
+      fast_clipping=(False, True),
   )
-  def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier,
-                          num_microbatches):
+  def testNoiseMultiplier(
+      self, l2_norm_clip, noise_multiplier, num_microbatches, fast_clipping
+  ):
     # The idea behind this test is to start with a model whose parameters
     # are set to zero. We then run one step of a model that produces
     # an un-noised gradient of zero, and then compute the standard deviation
@@ -197,69 +189,69 @@
     learning_rate = 1.0
 
-    for test_reg in get_layer_registries():
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss = tf.keras.losses.MeanSquaredError()
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss = tf.keras.losses.MeanSquaredError()
 
-      # Simple linear model returns w * x + b.
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=l2_norm_clip,
-          noise_multiplier=noise_multiplier,
-          num_microbatches=num_microbatches,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(1000,)),
-              tf.keras.layers.Dense(
-                  1, kernel_initializer='zeros', bias_initializer='zeros'
-              ),
-          ],
-      )
-      model.compile(optimizer=optimizer, loss=loss)
-      model.fit(train_data, train_labels, epochs=1, batch_size=4)
+    # Simple linear model returns w * x + b.
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=noise_multiplier,
+        num_microbatches=num_microbatches,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(1000,)),
+            tf.keras.layers.Dense(
+                1, kernel_initializer='zeros', bias_initializer='zeros'
+            ),
+        ],
+    )
+    model.compile(optimizer=optimizer, loss=loss)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4)
 
-      effective_num_microbatches = num_microbatches or train_data.shape[0]
+    effective_num_microbatches = num_microbatches or train_data.shape[0]
 
-      model_weights = model.get_weights()
-      measured_std = np.std(model_weights[0])
-      expected_std = (
-          l2_norm_clip * noise_multiplier / effective_num_microbatches
-      )
+    model_weights = model.get_weights()
+    measured_std = np.std(model_weights[0])
+    expected_std = l2_norm_clip * noise_multiplier / effective_num_microbatches
 
-      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
-      self.assertNear(measured_std, expected_std, 0.1 * expected_std)
+    # Test standard deviation is close to l2_norm_clip * noise_multiplier.
+    self.assertNear(measured_std, expected_std, 0.1 * expected_std)
 
   # Simple check to make sure dimensions are correct when output has
   # dimension > 1.
-  @parameterized.named_parameters(
-      ('mb_test None 2', None, 2),
-      ('mb_test 1 2', 1, 2),
-      ('mb_test 2 2', 2, 2),
-      ('mb_test 4 4', 4, 4),
+  @parameterized.product(
+      num_microbatches=(None, 1, 2),
+      output_dimension=(2, 4),
+      fast_clipping=(False, True),
   )
-  def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
+  def testMultiDimensionalOutput(
+      self, num_microbatches, output_dimension, fast_clipping
+  ):
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
     train_labels = np.array([[0], [1], [1], [0]])
     learning_rate = 1.0
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
-    for test_reg in get_layer_registries():
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=1.0e9,
-          noise_multiplier=0.0,
-          num_microbatches=num_microbatches,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(2,)),
-              tf.keras.layers.Dense(
-                  output_dimension, use_bias=False, kernel_initializer='zeros'
-              ),
-              tf.keras.layers.Dense(1),
-          ],
-      )
-      model.compile(optimizer=optimizer, loss=loss_fn)
-      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=1.0e9,
+        noise_multiplier=0.0,
+        num_microbatches=num_microbatches,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                output_dimension, use_bias=False, kernel_initializer='zeros'
+            ),
+            tf.keras.layers.Dense(1),
+        ],
+    )
+    model.compile(optimizer=optimizer, loss=loss_fn)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
 
   # Checks that calls to earlier API using `use_xla` as a positional argument
   # raise an exception.
@@ -285,10 +277,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
   # Simple test to check that regularizer gradients are contributing to the
   # final gradient.
   @parameterized.named_parameters(
-      ('no_registry', None),
-      ('default_registry', layer_registry.make_default_layer_registry()),
+      ('fast_clipping', True),
+      ('no_fast_clipping', False),
   )
-  def testRegularizationGradient(self, registry):
+  def testRegularizationGradient(self, fast_clipping):
     input_dim = 10
     batch_size = 2
     regularizer_multiplier = 0.025
@@ -306,7 +298,9 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
         outputs=outputs,
         l2_norm_clip=1e9,
         noise_multiplier=0.0,
-        layer_registry=registry,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
     )
     model.compile(
         loss=tf.keras.losses.MeanSquaredError(),
@@ -331,10 +325,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
   # Simple test to check that custom input regularization does NOT contribute
   # to the gradient.
   @parameterized.named_parameters(
-      ('no_registry', None),
-      ('default_registry', layer_registry.make_default_layer_registry()),
+      ('fast_clipping', True),
+      ('no_fast_clipping', False),
   )
-  def testCustomRegularizationZeroGradient(self, registry):
+  def testCustomRegularizationZeroGradient(self, fast_clipping):
     input_dim = 10
     batch_size = 2
     inputs = tf.keras.layers.Input((input_dim,))
@@ -350,7 +344,9 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
         outputs=outputs,
         l2_norm_clip=1e9,
         noise_multiplier=0.0,
-        layer_registry=registry,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
     )
     model.add_loss(tf.reduce_sum(inputs))
     model.compile(