diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 20fe19a..23f1b0c 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the `compute_gradient_norms()` function). """ -from typing import Any, Callable, Dict, Iterable, Optional, Text, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils @@ -129,11 +129,14 @@ def compute_gradient_norms( vars_list = [a for (a, b) in filtered_outputs] sqr_norm_fns_list = [b for (a, b) in filtered_outputs] # Second loop evaluates the squared L2 norm functions and appends the results. - grads_list = tape.gradient(summed_loss, vars_list) + grads_list = tape.gradient( + summed_loss, + vars_list, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) sqr_norm_list = [] for grads, f in zip(grads_list, sqr_norm_fns_list): sqr_norm_list.append(f(grads)) - del tape sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) @@ -163,24 +166,23 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms) -def compute_pred_and_clipped_gradients( +def compute_clipped_gradients_and_outputs( input_model: tf.keras.Model, x_batch: InputTensor, y_batch: tf.Tensor, l2_norm_clip: float, layer_registry: lr.LayerRegistry, num_microbatches: Optional[lr.BatchSize] = None, -): - """Computes the per-example predictions and per-example clipped loss gradient. +) -> Tuple[List[tf.Tensor], tf.Tensor, float]: + """Computes the per-example clipped loss gradient and other useful outputs. 
Given a batch of observations `(x_batch, y_batch)`, the main steps of this function are: (i) compute the l2-norm of the gradients of the trainable variables of `input_model` for each example in the batch; (ii) use the norms computed in (i) to obtain "clip_weights" that are used to generate a weighted loss function whose gradient for each example has l2-norm at most - `l2_norm_clip`; (iii) output the clipped gradients in (ii) and the - `tf.Tensor` generated by `input_model` when it is given `x_batch` as its - input. + `l2_norm_clip`; (iii) output the clipped gradients in (ii) and other useful + outputs to the caller. Args: input_model: The `tf.keras.Model` from which to obtain the layers from. @@ -204,9 +206,11 @@ def compute_pred_and_clipped_gradients( of num_microbatches). Returns: - A `tuple` `(y_pred, grad)`. The first element is the prediction generated by - the model on the input `x_batch`. The second element is the clipped - gradient of the loss function. + A `tuple` `(grad, y_pred, weighted_loss_value)`. The first element is the + clipped gradient of the loss function, the second is the result of + applying `input_model` to `x_batch`, and the third is the loss value of + `input_model`, weighted by the loss weights generated by a specific + `compute_clip_weights()` call. """ gradient_norms = compute_gradient_norms( input_model, @@ -217,18 +221,37 @@ def compute_pred_and_clipped_gradients( ) loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) with tf.GradientTape() as tape: - y_pred = input_model(x_batch, training=True) - if num_microbatches is not None: - y_batch = lr.add_microbatch_axis(y_batch, num_microbatches) - y_pred = lr.add_microbatch_axis(y_pred, num_microbatches) - # Warning: When num_microbatches is not None, we need to be sure that + # WARNING: When num_microbatches is not None, we need to be sure that # `compute_loss` always computes the mean over the microbatches # as it is the assumption made when computing the gradient norm. 
# It is indeed the case for multiple keras loss functions # (e.g. mean_squared_error and binary_crossentropy). However it # is not defined in the contract so may not hold, especially for # custom losses. - loss_value = input_model.compute_loss( - x_batch, y_batch, y_pred, loss_weights + y_pred = input_model(x_batch, training=True) + loss_y_batch = ( + y_batch + if num_microbatches is None + else lr.add_microbatch_axis(y_batch, num_microbatches) ) - return y_pred, tape.gradient(loss_value, input_model.trainable_variables) + loss_y_pred = ( + y_pred + if num_microbatches is None + else lr.add_microbatch_axis(y_pred, num_microbatches) + ) + # NOTE: We do not log the loss values here. The caller should invoke + # `input_model.compute_loss()` to log loss values. Specifically, + # calling `input_model.compute_loss()` performs the following steps: + # + # (i) sums `input_model.loss` with the regularization losses given in + # `input_model.losses` to obtain the total loss + # (ii) evaluates the total loss with sample weights (if given) + weighted_loss_value = input_model.loss( + loss_y_batch, loss_y_pred, loss_weights + ) + clipped_grads = tape.gradient( + weighted_loss_value, + input_model.trainable_variables, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + return clipped_grads, y_pred, weighted_loss_value diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 7bafd18..abf4fad 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -17,14 +17,14 @@ from absl import logging import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr def make_dp_model_class(cls): """Given a subclass of 
`tf.keras.Model`, returns a DP-SGD version of it.""" class DPModelClass(cls): # pylint: disable=missing-class-docstring - __doc__ = ( - """DP subclass of `{base_model}`. + __doc__ = ("""DP subclass of `{base_model}`. This can be used as a differentially private replacement for {base_model}. This class implements DP-SGD using the standard @@ -32,7 +32,7 @@ def make_dp_model_class(cls): This class also utilizes a faster gradient clipping algorithm if the following two conditions hold: - (i) the trainable layers of the model are keys in the `dict` input + (i) the trainable layers of the model are keys in the input `layer_registry`, (ii) the loss `tf.Tensor` for a given batch of examples is either a scalar or a 2D `tf.Tensor` that has only one column @@ -56,6 +56,11 @@ def make_dp_model_class(cls): It is the caller's responsibility to make sure that the loss function does behave this way. + WARNING: This API does not have privacy guarantees for custom + layer-level losses created by the `layer.add_loss()` API. It does, + however, support layer regularization losses. All of these layer-level + losses are found in `model.losses`. + When instantiating this class, you need to supply several DP-related arguments followed by the standard arguments for `{short_base_model}`. @@ -83,8 +88,7 @@ def make_dp_model_class(cls): model.fit(train_data, train_labels, epochs=1, batch_size=32) ``` - """ - ).format( + """).format( base_model='tf.keras.' + cls.__name__, short_base_model=cls.__name__, dp_model_class='DP' + cls.__name__, @@ -124,8 +128,10 @@ def make_dp_model_class(cls): # In particular, boolean values supplied to `use_xla` in the earlier API # will raise an error. if isinstance(num_microbatches, bool): - raise ValueError('Boolean value supplied for `num_microbatches`. ' - 'Did you intend it for `use_xla`?') + raise ValueError( + 'Boolean value supplied for `num_microbatches`. ' + 'Did you intend it for `use_xla`?' 
+ ) self._num_microbatches = num_microbatches # If all the trainable layers are in the input layer registry, we @@ -144,7 +150,8 @@ def make_dp_model_class(cls): if use_xla: self.train_step = tf.function( - self.train_step, experimental_compile=True) + self.train_step, experimental_compile=True + ) def _process_per_example_grads(self, grads): grads_flat = tf.nest.flatten(grads) @@ -152,7 +159,7 @@ def make_dp_model_class(cls): tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat ] global_norm = tf.sqrt(tf.add_n(squared_l2_norms)) - div = tf.maximum(global_norm / self._l2_norm_clip, 1.) + div = tf.maximum(global_norm / self._l2_norm_clip, 1.0) clipped_flat = [g / div for g in grads_flat] return tf.nest.pack_sequence_as(grads, clipped_flat) @@ -160,19 +167,23 @@ def make_dp_model_class(cls): summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0) noise_stddev = self._l2_norm_clip * self._noise_multiplier noise = tf.random.normal( - tf.shape(input=summed_grads), stddev=noise_stddev) + tf.shape(input=summed_grads), stddev=noise_stddev + ) noised_grads = summed_grads + noise - return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype) + return noised_grads / tf.cast( + tf.shape(stacked_grads)[0], noised_grads.dtype + ) def _compute_per_example_grads(self, data): - x, y = data + microbatched_x, microbatched_y = data with tf.GradientTape() as tape: - y_pred = self(x, training=True) - loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) - - grads_list = tape.gradient(loss, self.trainable_variables) + microbatched_y_pred = self(microbatched_x, training=True) + # NOTE: Calling `self.loss()` neither logs the total loss nor does it + # include any regularization terms. 
+ microbatched_loss = self.loss(microbatched_y, microbatched_y_pred) + grads_list = tape.gradient(microbatched_loss, self.trainable_variables) clipped_grads = self._process_per_example_grads(grads_list) - return y_pred, loss, clipped_grads + return microbatched_loss, clipped_grads def train_step(self, data): """DP-SGD version of base class method. @@ -184,7 +195,7 @@ def make_dp_model_class(cls): condition is satisfied if the model subclasses the keras.Sequential or keras.engine.functional.Functional class). - If (i) and (ii) above do not hold, then clips and aggregates + If (i) and (ii) above do not hold, then this function clips and aggregates gradients at the microbatch level. Args: @@ -193,6 +204,13 @@ def make_dp_model_class(cls): Returns: See the base class. """ + output_metrics = {} + x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data) + batch_size = tf.shape(y)[0] + eff_microbatch_size = self._num_microbatches or batch_size + privatized_loss_name = 'privatized_loss' + + # Branch based on gradient clipping algorithm. if self._enable_fast_peg_computation: logging.info( 'Computing gradients using the fast per-example gradient ' @@ -200,60 +218,73 @@ def make_dp_model_class(cls): ) # Computes the per-example gradient norms using a "fast" clipping # trick, and uses these norms to clip the per-example gradients. - x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data) - y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients( - self, - x, - y, - self._l2_norm_clip, - self._layer_registry, - self._num_microbatches, + # NOTE: Reshaping of the input according to the effective number of + # microbatches is done here. 
+ clipped_grads, y_pred, weighted_loss = ( + clip_grads.compute_clipped_gradients_and_outputs( + self, + x, + y, + self._l2_norm_clip, + self._layer_registry, + self._num_microbatches, + ) ) - batch_size = self._num_microbatches or tf.shape(y)[0] grads = gradient_clipping_utils.add_aggregate_noise( self, clipped_grads, - batch_size, + eff_microbatch_size, self._l2_norm_clip, self._noise_multiplier, ) + output_metrics[privatized_loss_name] = weighted_loss else: logging.info('Computing gradients using microbatching.') # Computes per-example clipped gradients directly. This is called # if at least one of the layers cannot use the "fast" gradient clipping # algorithm. - # TODO(wkong): check if the following is valid with sample weights. - _, y = data - batch_size = y.shape[0] - - if self._num_microbatches is None: - self._num_microbatches = batch_size - if batch_size % self._num_microbatches != 0: - raise ValueError('Number of_microbatches must divide batch size.') - - def reshape_fn(x): - new_shape = ( - self._num_microbatches, - batch_size // self._num_microbatches, - ) + x.shape[1:] - return tf.reshape(x, new_shape) - - data = tf.nest.map_structure(reshape_fn, data) - - y_pred, _, per_eg_grads = tf.vectorized_map( - self._compute_per_example_grads, data + reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_microbatch_size) + microbatched_data = tf.nest.map_structure(reshape_fn, data) + microbatched_losses, clipped_grads = tf.vectorized_map( + self._compute_per_example_grads, + microbatched_data, ) - - y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:]) - + y_pred = self(x, training=True) grads = tf.nest.map_structure( - self._reduce_per_example_grads, per_eg_grads + self._reduce_per_example_grads, clipped_grads ) + if self.loss.reduction == tf.keras.losses.Reduction.SUM: + microbatched_loss = tf.reduce_sum(microbatched_losses) + else: + microbatched_loss = tf.reduce_mean(microbatched_losses) + output_metrics[privatized_loss_name] = microbatched_loss + + # 
Add the values and gradients contributed by regularization losses. + if self.losses: + logging.warning( + 'Losses in `model.losses` must be input (batch) independent in ' + 'order to obtain the desired differential privacy guarantees.' + ) + with tf.GradientTape() as tape: + summed_regularization_loss = tf.add_n(self.losses) + regularization_grads = tape.gradient( + summed_regularization_loss, + self.trainable_variables, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + grads = [a + b for (a, b) in zip(grads, regularization_grads)] + output_metrics[privatized_loss_name] += summed_regularization_loss + + # Log the true loss. + self.compiled_loss(y, y_pred, regularization_losses=self.losses) # Forward the private gradients to the optimizer and return the results. self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) self.compiled_metrics.update_state(y, y_pred) - return {m.name: m.result() for m in self.metrics} + for m in self.metrics: + output_metrics[m.name] = m.result() + + return output_metrics return DPModelClass diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py index 59e60a4..1e4ac8a 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py @@ -282,5 +282,93 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): ], ) + # Simple test to check that regularizer gradients are contributing to the + # final gradient. 
+ @parameterized.named_parameters( + ('no_registry', None), + ('default_registry', layer_registry.make_default_layer_registry()), + ) + def testRegularizationGradient(self, registry): + input_dim = 10 + batch_size = 2 + regularizer_multiplier = 0.025 + inputs = tf.keras.layers.Input((input_dim,)) + dense_lyr = tf.keras.layers.Dense( + 1, + kernel_initializer='ones', + use_bias=False, + kernel_regularizer=tf.keras.regularizers.L2(regularizer_multiplier), + ) + # Zero-out outputs to avoid contributions from the main loss function. + outputs = tf.multiply(dense_lyr(inputs), 0.0) + model = dp_keras_model.DPModel( + inputs=inputs, + outputs=outputs, + l2_norm_clip=1e9, + noise_multiplier=0.0, + layer_registry=registry, + ) + model.compile( + loss=tf.keras.losses.MeanSquaredError(), + optimizer=tf.keras.optimizers.SGD(1.0), + run_eagerly=True, + ) + x_batch = tf.reshape( + tf.range(input_dim * batch_size, dtype=tf.float32), + [batch_size, input_dim], + ) + y_batch = tf.zeros([batch_size, 1]) + model.fit(x=x_batch, y=y_batch) + + self.assertAllClose( + model.trainable_variables, + tf.multiply( + tf.ones_like(model.trainable_variables), + 1.0 - 2.0 * regularizer_multiplier, + ), + ) + + # Simple test to check that custom input regularization does NOT contribute + # to the gradient. + @parameterized.named_parameters( + ('no_registry', None), + ('default_registry', layer_registry.make_default_layer_registry()), + ) + def testCustomRegularizationZeroGradient(self, registry): + input_dim = 10 + batch_size = 2 + inputs = tf.keras.layers.Input((input_dim,)) + dense_lyr = tf.keras.layers.Dense( + 1, + kernel_initializer='ones', + use_bias=False, + ) + # Zero-out outputs to avoid contributions from the main loss function. 
+ outputs = tf.multiply(dense_lyr(inputs), 0.0) + model = dp_keras_model.DPModel( + inputs=inputs, + outputs=outputs, + l2_norm_clip=1e9, + noise_multiplier=0.0, + layer_registry=registry, + ) + model.add_loss(tf.reduce_sum(inputs)) + model.compile( + loss=tf.keras.losses.MeanSquaredError(), + optimizer=tf.keras.optimizers.SGD(1.0), + run_eagerly=True, + ) + x_batch = tf.reshape( + tf.range(input_dim * batch_size, dtype=tf.float32), + [batch_size, input_dim], + ) + y_batch = tf.zeros([batch_size, 1]) + model.fit(x=x_batch, y=y_batch) + + self.assertAllClose( + model.trainable_variables, tf.ones_like(model.trainable_variables) + ) + + if __name__ == '__main__': tf.test.main()