From 3fa0a2d362da983a6aeac057cba817810d197550 Mon Sep 17 00:00:00 2001 From: William Kong Date: Wed, 24 Apr 2024 20:45:59 -0700 Subject: [PATCH] Add support for multi-headed models that use fast gradient clipping. PiperOrigin-RevId: 627942683 --- .../fast_gradient_clipping/clip_grads.py | 110 +++++++++++++++--- .../common_manip_utils.py | 24 ++-- .../gradient_clipping_utils.py | 26 ++++- .../registry_functions/layer_normalization.py | 1 - .../privacy/keras_models/dp_keras_model.py | 18 ++- .../keras_models/dp_keras_model_test.py | 61 ++++++++++ 6 files changed, 208 insertions(+), 32 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 9a5a96e..a316849 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -69,11 +69,80 @@ def get_registry_generator_fn( return registry_generator_fn +def _infer_per_example_loss_fn(model: tf.keras.Model): + """Infer the per-example loss from model config.""" + + def _convert(loss_fn): + loss_config = loss_fn.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + return loss_fn.from_config(loss_config) + + model_loss = model.loss + if isinstance(model_loss, tf.keras.losses.Loss): + return _convert(model_loss) + elif isinstance(model_loss, dict): + # Note that we cannot call the public method `.get_compile_config()` because + # it calls a numpy function, which is not supported inside a `tf.function` + # wrapped function. + compile_config = model._compile_config.config # pylint: disable=protected-access + if compile_config is None: + raise ValueError('Model must be compiled for loss function conversion') + # Does a weighted mean of the configured losses. Note that we cannot build + # from the config of the compiled loss because (i) it builds a + # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s + # during its construction, (ii) non-unique `tf.Variables` cannot be used + # inside a `tf.function`, which is usually where this function is used. + if 'loss_weights' not in compile_config: + raise ValueError( + 'Models with multiple loss must have corresponding loss weights for' + ' loss function conversion' + ) + weights = compile_config['loss_weights'] + per_example_losses = {k: _convert(v) for k, v in model_loss.items()} + num_losses = len(weights) + + def _per_example_loss_fn(y_true, y_pred, sample_weight=None): + loss_values = [] + if model_loss.keys() - y_pred.keys(): + raise ValueError( + 'y_pred must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_pred.keys(), model_loss.keys()) + ) + if model_loss.keys() - y_true.keys(): + raise ValueError( + 'y_true must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + if sample_weight is not None: + if model_loss.keys() - sample_weight.keys(): + raise ValueError( + 'sample_weight must contain the same keys and the model losses,' + ' but got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + for k in y_true.keys(): + sgl_sample_weight = None if sample_weight is None else sample_weight[k] + sgl_value = ( + weights[k] + * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight) + / num_losses + ) + loss_values.append(tf.reshape(sgl_value, shape=[-1])) + return tf.math.add_n(loss_values) + + return _per_example_loss_fn + else: + raise ValueError( + 'Unsupported type for loss function conversion: {}'.format( + type(model_loss) + ) + ) + + def compute_gradient_norms( input_model: tf.keras.Model, layer_registry: lr.LayerRegistry, x_batch: type_aliases.InputTensors, - y_batch: tf.Tensor, + y_batch: type_aliases.OutputTensors, weight_batch: Optional[tf.Tensor] = None, per_example_loss_fn: Optional[type_aliases.LossFn] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, @@ -94,9 +163,9 @@ def compute_gradient_norms( more details. x_batch: An `InputTensor` representing a batch of inputs to the model. The first axis must be the batch dimension. - y_batch: A `tf.Tensor` representing a batch of output labels. The first axis - must be the batch dimension. The number of examples should match the - number of examples in `x_batch`. + y_batch: An `OutputTensor` representing a batch of output labels. The first + axes of the tensors must be the batch dimension. The number of examples + should match the number of examples in `x_batch`. weight_batch: Optional batch of weights, passed to the loss function. Weights apply to the loss prior to clipping. per_example_loss_fn: takes as input predictions, labels and weights, and @@ -131,11 +200,11 @@ def compute_gradient_norms( input_model, x_batch, generator_fn=registry_generator_fn ) ) + # Ignore the original loss function's reduction to get per-example loss. if per_example_loss_fn is None: - loss_config = input_model.loss.get_config() - loss_config['reduction'] = tf.keras.losses.Reduction.NONE - per_example_loss_fn = input_model.loss.from_config(loss_config) + per_example_loss_fn = _infer_per_example_loss_fn(input_model) + losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) if losses.shape is None: raise NotImplementedError( @@ -233,7 +302,7 @@ def compute_clipped_gradients_and_outputs( l2_norm_clip: float, layer_registry: lr.LayerRegistry, x_batch: type_aliases.InputTensors, - y_batch: tf.Tensor, + y_batch: type_aliases.OutputTensors, weight_batch: Optional[tf.Tensor] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, clipping_loss: Optional[type_aliases.LossFn] = None, @@ -260,10 +329,10 @@ def compute_clipped_gradients_and_outputs( squared norms of a layer's pre-activation tensor, and `vars` are relevant trainable weights (see `layer_registry_factories.py` for examples). x_batch: An `InputTensor` representing a batch of inputs to the model. The - first axis must be the batch dimension. - y_batch: A `tf.Tensor` representing a batch of output labels. The first axis - must be the batch dimension. The number of examples should match the - number of examples in `x_batch`. + first axes of each tensor must be the batch dimension. + y_batch: An `OutputTensor` representing a batch of output labels. The first + axes of each tensor must be the batch dimension. The number of examples + should match the number of examples in `x_batch`. weight_batch: Optional vector of weights, passed to the loss function. Must be of size [batch_size]. In case of microbatching, this will be reshaped to [num_microbatches, batch_size/num_microbatches] before passing it to @@ -285,11 +354,12 @@ def compute_clipped_gradients_and_outputs( clipping_loss_value: the loss value weighted in such a way that its gradient is `clipped_grad`. """ - if input_model.loss.reduction == 'none': - raise NotImplementedError( - 'Fast gradient clipping does not support ' - 'models with unreduced loss functions.' - ) + if hasattr(input_model.loss, 'reduction'): + if input_model.loss.reduction == 'none': + raise NotImplementedError( + 'Fast gradient clipping does not support ' + 'models with unreduced loss functions.' + ) if clipping_loss is None: clipping_loss = input_model.compiled_loss gradient_norms = compute_gradient_norms( @@ -311,11 +381,13 @@ def compute_clipped_gradients_and_outputs( weight_batch, num_microbatches ) if num_microbatches is None: - clip_weights = clip_weights * weight_batch # shape [num_microbatches] + c = clip_weights # shape [num_microbatches] else: # In this case, weight_batch is of shape [batch_size, microbatch_size], # we multiply by the clip_weights (which is of shape [num_microbatches]) - clip_weights = clip_weights[:, tf.newaxis] * weight_batch + c = clip_weights[:, tf.newaxis] + clip_weights = tf.nest.map_structure(lambda w: c * w, weight_batch) + with tf.GradientTape() as tape: # WARNING: When num_microbatches is not None, we need to be sure that # `compute_loss` always computes the mean over the microbatches diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/common_manip_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/common_manip_utils.py index db05f0a..e5f5907 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/common_manip_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/common_manip_utils.py @@ -20,13 +20,13 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases def maybe_add_microbatch_axis( - x: tf.Tensor, + x: type_aliases.PackedTensors, num_microbatches: Optional[type_aliases.BatchSize], -) -> tf.Tensor: - """Adds the microbatch axis. +) -> type_aliases.PackedTensors: + """Adds the microbatch axis to a collection of tensors. Args: - x: the input tensor. + x: Model output or input tensors. num_microbatches: If None, x is returned unchanged. Otherwise, must divide the batch size. @@ -36,9 +36,13 @@ def maybe_add_microbatch_axis( """ if num_microbatches is None: return x - with tf.control_dependencies( - [tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)] - ): - return tf.reshape( - x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0) - ) + + def _expand(t): + with tf.control_dependencies( + [tf.assert_equal(tf.math.floormod(tf.shape(t)[0], num_microbatches), 0)] + ): + return tf.reshape( + t, tf.concat([[num_microbatches, -1], tf.shape(t)[1:]], axis=0) + ) + + return tf.nest.map_structure(_expand, x) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index adca9af..7416318 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -144,6 +144,30 @@ def all_trainable_layers_are_registered( return True +def _infer_loss_reduction_type(model: tf.keras.Model): + """Infers what type of loss reduction is being performed.""" + model_loss = model.loss + if isinstance(model_loss, tf.keras.losses.Loss): + return model_loss.reduction + elif isinstance(model.loss, dict): + reductions = set() + compiled_loss = model.compiled_loss + if compiled_loss is None: + raise ValueError('Model must be compiled for adding noise') + new_config_list = compiled_loss.get_config()['losses'] + for loss_config in new_config_list: + reductions.add(loss_config['config']['reduction']) + if len(reductions) > 1: + raise ValueError( + 'Reductions in models with multiple losses must all be the same' + ) + return reductions.pop() + else: + raise ValueError( + 'Unsupported type for adding noise: {}'.format(type(model_loss)) + ) + + def add_aggregate_noise( clipped_grads: list[tf.Tensor], batch_size: tf.Tensor, @@ -194,7 +218,7 @@ def add_aggregate_noise( tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, tf.keras.losses.Reduction.AUTO, ] - model_reduction = loss_model.loss.reduction + model_reduction = _infer_loss_reduction_type(loss_model) loss_reduction = ( 'mean' if model_reduction in implicit_mean_reductions else 'sum' ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py index f9248de..849ace6 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py @@ -82,7 +82,6 @@ def layer_normalization_computation( stacked_grads = common_manip_utils.maybe_add_microbatch_axis( grads, num_microbatches ) - stacked_grads = tf.reduce_sum(stacked_grads, axis=1) reduction_axes = tf.range(1, tf.rank(stacked_grads)) return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 6217246..9427713 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -213,6 +213,22 @@ def make_dp_model_class(cls): total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME), ) + def _infer_batch_size(self, t): + """Infers the batch size from a tensor or a container of tensors.""" + if t is None: + return None + elif isinstance(t, tf.Tensor): + t0 = t + elif isinstance(t, list): + t0 = t[0] + elif isinstance(t, dict): + t0 = list(t.values())[0] + else: + raise ValueError( + 'Unsupported type for batch size inference: {}'.format(type(t)) + ) + return tf.shape(t0)[0] + def train_step(self, data): """DP-SGD version of base class method. @@ -236,7 +252,7 @@ def make_dp_model_class(cls): self._make_clipping_loss() output_metrics = {} x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data) - batch_size = tf.shape(y)[0] + batch_size = self._infer_batch_size(y) num_microbatches = self._num_microbatches or batch_size # Branch based on gradient clipping algorithm. diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py index 23acb37..7cf46d0 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py @@ -425,6 +425,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): model.trainable_variables, tf.ones_like(model.trainable_variables) ) + # Checks single optimizer update and consistency with non-DP model. + @parameterized.product( + fast_clipping=[True, False], + num_microbatches=[None, 2], + ) + def testWeightedLoss(self, fast_clipping, num_microbatches): + k1 = 'foo' + k2 = 'bar' + input_dim = 10 + default_layer_registry = ( + layer_registry.make_default_layer_registry() if fast_clipping else None + ) + + def _make_base_model(use_dp: bool): + inputs = { + k1: tf.keras.layers.Input((input_dim,)), + k2: tf.keras.layers.Input((input_dim,)), + } + dense1 = tf.keras.layers.Dense(1, kernel_initializer='ones') + dense2 = tf.keras.layers.Dense(1, kernel_initializer='ones') + outputs = {k1: dense1(inputs[k1]), k2: dense2(inputs[k2])} + if use_dp: + base_model = dp_keras_model.DPModel( + inputs=inputs, + outputs=outputs, + l2_norm_clip=1e9, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + layer_registry=default_layer_registry, + ) + else: + base_model = tf.keras.Model(inputs=inputs, outputs=outputs) + losses = { + k1: tf.keras.losses.MeanSquaredError(), + k2: tf.keras.losses.MeanAbsoluteError(), + } + loss_weights = {k1: 0.4, k2: 0.6} + sgd = tf.keras.optimizers.SGD(1.0) + base_model.compile(loss=losses, loss_weights=loss_weights, optimizer=sgd) + return base_model + + dp_model = _make_base_model(use_dp=True) + model = _make_base_model(use_dp=False) + batch_size = 4 + x_batch = { + k1: tf.reshape( + tf.range(input_dim * batch_size, dtype=tf.float32), + [batch_size, input_dim], + ), + k2: 2.0 * tf.reshape( + tf.range(input_dim * batch_size, dtype=tf.float32), + [batch_size, input_dim], + ), + } + y_batch = {k1: tf.zeros([batch_size, 1]), k2: tf.ones([batch_size, 1])} + + # Checks that gradients align. + model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False) + dp_model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False) + self.assertAllClose(dp_model.trainable_variables, model.trainable_variables) + if __name__ == '__main__': tf.test.main()