diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD new file mode 100644 index 0000000..110d054 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -0,0 +1,25 @@ +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "gradient_clipping_utils", + srcs = ["gradient_clipping_utils.py"], + srcs_version = "PY3", +) + +py_library( + name = "layer_registry_factories", + srcs = ["layer_registry_factories.py"], + srcs_version = "PY3", +) + +py_library( + name = "clip_grads", + srcs = ["clip_grads.py"], + srcs_version = "PY3", + deps = [ + ":gradient_clipping_utils", + ":layer_registry_factories", + ], +) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py new file mode 100644 index 0000000..93840d6 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -0,0 +1,185 @@ +# Copyright 2022, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Computes per-example loss clip weights. + +For a given Keras model and batch of inputs, computes the per-example +clip weights so that the gradient of the loss function, weighted by these +weights, is equivalent to the gradient of the original loss function but +with the per-example gradients clipped by some clip weight. Uses a variant +of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the +`compute_gradient_norms()` function). +""" + +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils + + +def combine_pre_and_post_sqr_norms(pre_sqr_norm, post_grad, layer_hash): + """Combines pre and post-activation tensors for a given variable. + + The logic for combining norms depends on the variable's underlying layer. + + Args: + pre_sqr_norm: A `tf.Tensor` whose first dimension is the batch dimension. + Contains squared norms that are related to the pre-activation Tensor. + post_grad: A `tf.Tensor` whose first dimension is the batch dimension. + Contains gradients that are related to the post-activation Tensor. + layer_hash: A `float` that is the hash of the variable's underlying layer + class. + + Returns: + A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th + per-example loss function with respect to the given variable. + """ + post_sqr_grads = tf.square(post_grad) + if layer_hash == hash(tf.keras.layers.Embedding): + scaled_grads = tf.expand_dims(pre_sqr_norm, axis=-1) * post_sqr_grads + reduction_axes = tf.range(1, tf.rank(scaled_grads)) + return tf.reduce_sum(scaled_grads, axis=reduction_axes) + else: + reduction_axes = tf.range(1, tf.rank(post_sqr_grads)) + post_sqr_norm = tf.reduce_sum(post_sqr_grads, axis=reduction_axes) + return pre_sqr_norm * post_sqr_norm + + +def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry): + """Computes the per-example loss gradient norms for given data. + + Applies the approach given in https://arxiv.org/pdf/2009.03106.pdf, except + the batch matrix multiplication operation in Algorithm 2 is replaced with + the computation of two norm computations. + + Args: + input_model: The `tf.keras.Model` from which to obtain the layers from. The + loss of the model *must* be a scalar loss. + x_batch: A `tf.Tensor` representing a batch of inputs to the model. The + first axis must be the batch dimension. + y_batch: A `tf.Tensor` representing a batch of output labels. The first axis + must be the batch dimension. The number of examples should match the + number of examples in `x_batch`. + layer_registry: A `dict` of layers that support "fast" gradient norm + computations. The key is the class of the layer and the value is a + function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where + `output` is the pre-activator tensor, `sqr_grad_norms` is related to the + squared norms of a layer's pre-activation tensor, and `vars` are relevant + trainable weights (see `layer_registry_factories.py` for examples). + + Returns: + A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th + per-example loss function. + """ + tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) + # First loop computes the norms of the layer inputs, caches these inputs, + # and computes the summed loss. + with tape: + model_outputs, pre_norm_list, var_list, layer_hash_list = ( + gradient_clipping_utils.forward_norm_pass( + input_model, x_batch, tape, layer_registry + ) + ) + # Ignore the original loss function's reduction to get per-example loss. + loss_config = input_model.loss.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + per_example_loss_fn = input_model.loss.from_config(loss_config) + losses = per_example_loss_fn(y_batch, model_outputs) + if tf.rank(tf.squeeze(losses)) > 1: + raise NotImplementedError('Vector losses are not supported.') + summed_loss = tf.reduce_sum(losses) + # Second loop computes the norm of the gradient of the loss with respect to + # the pre-activation tensors, and multiplies these norms with the results of + # the first loop. + full_norm_list = [] + grads = tape.gradient(summed_loss, var_list) + for i in range(len(var_list)): + full_norm = combine_pre_and_post_sqr_norms( + pre_norm_list[i], grads[i], layer_hash_list[i] + ) + full_norm_list.append(full_norm) + del tape + # Post-processing for compatibility with non-eager mode (very annoying). + full_norm_tsr = tf.stack(full_norm_list, axis=1) + return tf.sqrt(tf.reduce_sum(full_norm_tsr, axis=1)) + + +def compute_clip_weights(l2_norm_clip, gradient_norms): + """Computes the per-example loss/clip weights for clipping. + + When the sum of the per-example losses is replaced a weighted sum, where + the weights are generated by this method, then the gradients of each + term in the weighted sum are clipped by the given clip value. + + Args: + l2_norm_clip: A `float` indicating the norm to which per-example gradients + will be clipped. That is, all gradients of the per-example loss functions + will have norm at most `l2_norm_clip`. + gradient_norms: A 1D `tf.Tensor` whose i-th entry is the norm of the + gradient of the loss function for the i-th input. + + Returns: + A 1D `tf.Tensor` representing whose i-th entry `C[i]` is either `1.0` if the + norm of the gradient of i-th per-example loss `G[i]` is less than + `l2_norm_clip` or a number less than `1.0` so that + `|G[i]| * C[i] == l2_norm_clip` otherwise. + """ + if l2_norm_clip is None: + return None + return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms) + + +def compute_pred_and_clipped_gradients( + input_model, x_batch, y_batch, l2_norm_clip, layer_registry +): + """Computes the per-example predictions and per-example clipped loss gradient. + + Given a batch of observations `(x_batch, y_batch)`, the main steps of this + function are: (i) compute the l2-norm of the gradients of the trainable + variables of `input_model` for each example in the batch; (ii) use the norms + computed in (i) to obtain "clip_weights" that are used to generate a weighted + loss function whose gradient for each example has l2-norm at most + `l2_norm_clip`; (iii) output the clipped gradients in (ii) and the + `tf.Tensor` generated by `input_model` when it is given `x_batch` as its + input. + + Args: + input_model: The `tf.keras.Model` from which to obtain the layers from. + x_batch: A `tf.Tensor` representing a batch of inputs to the model. The + first axis must be the batch dimension. + y_batch: A `tf.Tensor` representing a batch of output labels. The first axis + must be the batch dimension. The number of examples should match the + number of examples in `x_batch`. + l2_norm_clip: A `float` indicating the norm to which per-example gradients + will be clipped. That is, all gradients of the per-example loss functions + will have norm at most `l2_norm_clip`. + layer_registry: A `dict` of layers that support "fast" gradient norm + computations. The key is the class of the layer and the value is a + function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where + `output` is the pre-activator tensor, `sqr_grad_norms` is related to the + squared norms of a layer's pre-activation tensor, and `vars` are relevant + trainable weights (see `layer_registry_factories.py` for examples). + + Returns: + A `tuple` `(y_pred, grad)`. The first element is the prediction generated by + the model on the input `x_batch`. The second element is the clipped + gradient of the loss function. + """ + gradient_norms = compute_gradient_norms( + input_model, x_batch, y_batch, layer_registry + ) + loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) + with tf.GradientTape() as tape: + y_pred = input_model(x_batch, training=True) + loss_value = input_model.compute_loss( + x_batch, y_batch, y_pred, loss_weights + ) + return y_pred, tape.gradient(loss_value, input_model.trainable_variables) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py new file mode 100644 index 0000000..96e95ed --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -0,0 +1,214 @@ +# Copyright 2022, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions that help in the computation of per-example gradient norms.""" + +from absl import logging +import tensorflow as tf + + +def has_internal_compute_graph(input_object): + """Checks if input is a TF model and has a TF internal compute graph.""" + return ( + isinstance(input_object, tf.keras.Model) + and hasattr(input_object, '_flatten_to_reference_inputs') + and hasattr(input_object, '_tensor_usage_count') + and hasattr(input_object, '_conform_to_reference_input') + and hasattr(input_object, '_nodes_by_depth') + ) + + +def forward_norm_pass(input_model, x_batch, tape, layer_registry): + """Does a forward pass of a model and returns some useful intermediates. + + NOTE: the graph traversal algorithm is an adaptation of the logic in the + _run_internal_graph() method in the functional.Functional class. Hence, + forward_norm_pass should only be invoked if the generated model + instance is an instance of the functional.Functional class. + + Args: + input_model: A Keras functional model to compute the quantities for. + x_batch: A collection of Tensors to be fed into the input layer of the + model. + tape: A tf.GradientTape() instance used to record certain operations. + Assumes that this function is being called inside of this tape. + layer_registry: A dictionary of layers that support "fast" gradient norm + computations. The key is the class of the layer and the value is a + function that returns a triple (norm_list, var_list, transform). For more + details, see `layer_registry_factories.py`. + + Returns: + Four objects (outputs, norm_list, var_list, layer_hash_list). The first are + the outputs that are generated as a result of a forward pass. The second is + either `None` or a collection of squared norms for each example that helps + in computing the gradient norm of an example (if such a "nice" computation + exists). The third is an ordered list of tf.Tensor() objects that are + intended to be evaluated with respect to the summed loss of the model. The + fourth is a list whose i-th element is the hash of the layer class of the + i-th element of var_list. + """ + # TODO: Avoid or remove the references to protected methods of `input_model`. # pylint: disable=g-bad-todo + # Prepare the inputs and BFS variables. + flattened_inputs = input_model._flatten_to_reference_inputs(x_batch) # pylint: disable=protected-access + tensor_dict = {} + tensor_usage_count = input_model._tensor_usage_count # pylint: disable=protected-access + for x, y in zip(input_model.inputs, flattened_inputs): + y = input_model._conform_to_reference_input(y, ref_input=x) # pylint: disable=protected-access + x_id = str(id(x)) + tensor_dict[x_id] = [y] * tensor_usage_count[x_id] + + # Main computations. + nodes_by_depth = input_model._nodes_by_depth # pylint: disable=protected-access + depth_keys = list(nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + var_list = [] + norm_list = [] + layer_hash_list = [] + node_outputs = None + # Perform BFS feedforward computations. + for depth in depth_keys: + for node in nodes_by_depth[depth]: + if node.is_input: + continue # inputs already exist + if any(t_id not in tensor_dict for t_id in node.flat_input_ids): + continue # node is not computable; try skipping + args, kwargs = node.map_arguments(tensor_dict) + if has_internal_compute_graph(node.layer): + # If this node has an internal computational graph, we can recurse. + node_outputs, node_norm_list, node_var_list, node_layer_hash_list = ( + forward_norm_pass(node.layer, args, tape, layer_registry) + ) + var_list += node_var_list + norm_list += node_norm_list + layer_hash_list += node_layer_hash_list + else: + # Either pass through or record some metadata. + if not node.layer.trainable_variables: + node_outputs = node.layer(*args, **kwargs) + else: + lyr_hash = hash(node.layer.__class__) + if lyr_hash not in layer_registry: + raise NotImplementedError( + 'Layer %s is not in the registry of known layers that can' + 'be used for efficient gradient clipping.' + % node.layer.__class__.__name__ + ) + lyr = layer_registry[lyr_hash] + node_norms, node_vars, transform = lyr(node.layer, args) + tape.watch(node_vars) + node_outputs = transform(node_vars) if transform else node_vars + var_list.append(node_vars) + norm_list.append(node_norms) + layer_hash_list.append(lyr_hash) + # update the current dictionary of inputs for the next node. + for x_id, y in zip(node.flat_output_ids, tf.nest.flatten(node_outputs)): + tensor_dict[x_id] = [y] * tensor_usage_count[x_id] + + return node_outputs, norm_list, var_list, layer_hash_list + + +def get_trainable_hidden_layers(input_model): + """Obtains the trainable hidden layers of a Keras model. + + Args: + input_model: The Keras model to obtain the layers from. + + Returns: + A list of Keras layers where the tensorflow.keras.layers.Layer + ancestor class MUST precede any existing tensorflow.keras.models.Model + ancestor class. + """ + hidden_layers = [] + for l in input_model.layers: + for c in l.__class__.__mro__: + if c == tf.keras.models.Model: + hidden_layers += get_trainable_hidden_layers(l) + break + elif c == tf.keras.layers.InputLayer: + break + elif c == tf.keras.layers.Layer: + if l.trainable_variables: + hidden_layers.append(l) + break + return hidden_layers + + +def all_trainable_layers_are_registered(input_model, layer_registry): + """Check if an input model's trainable layers are all registered. + + Args: + input_model: The Keras model from which to obtain the layers from. + layer_registry: A dictionary of layers that support "fast" gradient norm + computations. The key is the class of the layer and the value is a + function that returns a triple (output, sqr_grad_norms, vars), where + output is the pre-activator tensor, sqr_grad_norms is the square of the + norm of the layer's input, and vars is an ordered list of the trainable + weights. + + Returns: + True if all the trainable layers in `input_model` are in `layer_registry`. + False otherwise. + """ + hidden_layers = get_trainable_hidden_layers(input_model) + for l in hidden_layers: + if hash(l.__class__) not in layer_registry: + return False + return True + + +def add_aggregate_noise( + input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier +): + """Adds noise to a collection of clipped gradients. + + The magnitude of the noise depends on the aggregation strategy of the + input model's loss function. + + Args: + input_model: The Keras model to obtain the layers from. + x_batch: A collection of Tensors to be fed into the input layer of the + model. + clipped_grads: A list of tensors representing the clipped gradients. + l2_norm_clip: Clipping norm (max L2 norm of each gradient). + noise_multiplier: Ratio of the standard deviation to the clipping norm. + + Returns: + A list of tensors containing the clipped gradients, but with the right + amount of Gaussian noise added to them (depending on the reduction + strategy of the loss function). + """ + scale = l2_norm_clip + if input_model.loss.reduction in [ + tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, + tf.keras.losses.Reduction.AUTO, + ]: + if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO: + logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.') + if isinstance(x_batch, tf.Tensor): + scale /= tf.cast(tf.shape(x_batch)[0], tf.float32) + elif isinstance(x_batch, dict): + batch_sizes = [ + tf.cast(tf.shape(v)[0], tf.float32) for v in x_batch.values() + ] + scale /= tf.math.reduce_min(batch_sizes) + else: + raise NotImplementedError( + 'Unknown container/class %s for input' % x_batch.__class__.__name__ + ) + + def add_noise(g): + return g + tf.random.normal( + tf.shape(g), mean=0.0, stddev=noise_multiplier * scale + ) + + return tf.nest.map_structure(add_noise, clipped_grads) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry_factories.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry_factories.py new file mode 100644 index 0000000..6b434de --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry_factories.py @@ -0,0 +1,140 @@ +# Copyright 2022, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generates a default layer registry. + +Defines "fast" gradient norm layer registry functions for use in the "fast" +gradient clipping algorithm. Specifically, each registry function takes +in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three +outputs: (a) a `tf.Tensor` `G` of gradient norms, (b) a differentiable +`tf.Tensor` `Z`, and (c) either `None` or a function that maps the object in +(b) to the layer instance's output when using the inputs in (ii). If (c) is +`None`, then (b) contains the layer outputs. + +When a layer registry function is defined, it is generally assumed that the +following relation holds for each pair `(g, z)` in `zip(G, Z)`: + + `|dL/dw|^2 == |dL/dz|^2 * g^2` + +where `L` is any per-example loss function and `w` are the trainable variables +corresponding to `(g, z)`. + +For example, this relation holds if the layer instance is tf.keras.layers.Dense, +Z contains the pre-activation tensors, i.e., `z = X * w` for input `X`, and `g` +is the norm of the input corresponding to the given per-example loss (see the +formulae in https://arxiv.org/abs/1510.01799 for more details). + +The registry functions are registered in a `dict` (registry) whose key is the +hash of the layer class and whose value is the registry function. +""" + +import tensorflow as tf + + +# ============================================================================== +# Supported Keras layers +# ============================================================================== +def dense_layer_computation(layer_instance, inputs): + """Registry function for `tf.keras.layers.Dense`. + + The logic for this computation is based on the following paper: + https://arxiv.org/abs/1510.01799 + + For the sake of efficiency, we fuse the variables and square grad norms + for the kernel weights and bias vector together. + + Args: + layer_instance: A `tf.keras.layers.Dense` instance. + inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., + `layer_instance(inputs)` returns a valid output. + + Returns: + A `tuple` `(sqr_grad_norms, base_vars, transform)`, where `norms` is a 1D + `tf.Tensor` of the squared l2-norms of the input tensors, `base_vars` is the + intermediate Tensor used in the chain-rule / "fast" clipping trick, and + `transform` is a function that maps `base_vars` to the layer outputs. + """ + orig_activation = layer_instance.activation + layer_instance.activation = None + base_vars = layer_instance(*inputs) + sqr_inputs = tf.square(*inputs) + inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs)) + sqr_grad_norms = tf.reduce_sum(tf.square(*inputs), axis=inputs_reduction_axes) + if layer_instance.use_bias: + # Adding a bias term is equivalent to a layer with no bias term and which + # adds an additional variable to the layer input that only takes a constant + # value of 1.0. This is thus equivalent to adding 1.0 to the sum of the + # squared values of the inputs. + sqr_grad_norms += tf.cast(1.0, dtype=sqr_grad_norms.dtype) + layer_instance.activation = orig_activation + return sqr_grad_norms, base_vars, layer_instance.activation + + +def embedding_layer_computation(layer_instance, inputs): + """Registry function for `tf.keras.layers.Embedding`. + + The logic of this computation is based on the `tf.keras.layers.Dense` + computation and the fact that an embedding layer is just a dense layer + with no activation function and an output vector of the form X*W for input + X, where the i-th row of W is the i-th embedding vector and the j-th row of + X is a one-hot vector representing the input of example j. + + Args: + layer_instance: A `tf.keras.layers.Embedding` instance. + inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., + `layer_instance(inputs)` returns a valid output. + + Returns: + A `tuple` `(sqr_grad_norms, base_vars, None)`, where `sqr_grad_norms` is + a `tf.Tensor` that is related to the squared l2-norms of the input tensors + and `base_vars` is the intermediate Tensor used in the chain-rule / "fast" + clipping trick. + """ + if layer_instance.sparse: + raise NotImplementedError("Sparse output vectors are not supported.") + # The logic below is applied to properly handle repeated embedding indices. + # Specifically, sqr_grad_norms will contain the total counts of each embedding + # index (see how it is processed in the combine_pre_and_post_sqr_norms() + # function in clip_grads.py). An example is as follows: + # + # inputs = + # [[0 0 0 1 2 2], + # [0 2 2 2 1 1]] + # + # counts = + # [[3 1 2] + # [1 2 3]] + # + # sqr_grad_norms = + # [[3 3 3 1 2 2], + # [1 3 3 3 2 2]] + # + base_vars = layer_instance(*inputs) + indices = tf.cast(*inputs, tf.int32) + if isinstance(indices, tf.SparseTensor): + indices = tf.sparse.to_dense(indices) + counts = tf.math.bincount(indices, axis=-1) + sqr_grad_norms = tf.cast( + tf.gather(counts, indices, batch_dims=1), base_vars.dtype + ) + return sqr_grad_norms, base_vars, None + + +# ============================================================================== +# Main factory methods +# ============================================================================== +def make_default_layer_registry(): + registry = {} + registry[hash(tf.keras.layers.Dense)] = dense_layer_computation + registry[hash(tf.keras.layers.Embedding)] = embedding_layer_computation + return registry