diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
index f0b03bb..ffa666f 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -9,8 +9,8 @@ py_library(
 )
 
 py_library(
-    name = "layer_registry_factories",
-    srcs = ["layer_registry_factories.py"],
+    name = "layer_registry",
+    srcs = ["layer_registry.py"],
     srcs_version = "PY3",
 )
 
@@ -20,7 +20,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":gradient_clipping_utils",
-        ":layer_registry_factories",
+        ":layer_registry",
     ],
 )
 
@@ -31,6 +31,6 @@ py_test(
     srcs_version = "PY3",
     deps = [
         ":clip_grads",
-        ":layer_registry_factories",
+        ":layer_registry",
     ],
 )
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
index 439febc..b2c9dd3 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -25,40 +25,42 @@ import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 
 
-def combine_pre_and_post_sqr_norms(pre_sqr_norm, post_grad, layer_hash):
-  """Combines pre and post-activation tensors for a given variable.
-
-  The logic for combining norms depends on the variable's underlying layer.
-
-  Args:
-    pre_sqr_norm: A `tf.Tensor` whose first dimension is the batch dimension.
-      Contains squared norms that are related to the pre-activation Tensor.
-    post_grad: A `tf.Tensor` whose first dimension is the batch dimension.
-      Contains gradients that are related to the post-activation Tensor.
-    layer_hash: A `float` that is the hash of the variable's underlying layer
-      class.
-
-  Returns:
-    A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
-    per-example loss function with respect to the given variable.
-  """
-  post_sqr_grads = tf.square(post_grad)
-  if layer_hash == hash(tf.keras.layers.Embedding):
-    scaled_grads = tf.expand_dims(pre_sqr_norm, axis=-1) * post_sqr_grads
-    reduction_axes = tf.range(1, tf.rank(scaled_grads))
-    return tf.reduce_sum(scaled_grads, axis=reduction_axes)
+def get_registry_generator_fn(tape, layer_registry):
+  """Creates the generator function for `compute_gradient_norms()`."""
+  if layer_registry is None:
+    # Needed for backwards compatibility.
+    registry_generator_fn = None
   else:
-    reduction_axes = tf.range(1, tf.rank(post_sqr_grads))
-    post_sqr_norm = tf.reduce_sum(post_sqr_grads, axis=reduction_axes)
-    return pre_sqr_norm * post_sqr_norm
+
+    def registry_generator_fn(layer_instance, args, kwargs):
+      if layer_instance.trainable_variables:
+        # Only trainable variables factor into the gradient.
+        if not layer_registry.is_elem(layer_instance):
+          raise NotImplementedError(
+              'Layer %s is not in the registry of known layers that can '
+              'be used for efficient gradient clipping.'
+              % layer_instance.__class__.__name__
+          )
+        registry_fn = layer_registry.lookup(layer_instance)
+        (layer_vars, transform, layer_sqr_norm_fn) = registry_fn(
+            layer_instance, args
+        )
+        if tape is not None:
+          tape.watch(layer_vars)
+        layer_outputs = transform(layer_vars) if transform else layer_vars
+        return layer_outputs, (layer_vars, layer_sqr_norm_fn)
+      else:
+        # Non-trainable layer.
+        return layer_instance(*args, **kwargs), None
+
+  return registry_generator_fn
 
 
 def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   """Computes the per-example loss gradient norms for given data.
 
-  Applies the approach given in https://arxiv.org/pdf/2009.03106.pdf, except
-  the batch matrix multiplication operation in Algorithm 2 is replaced with
-  the computation of two norm computations.
+  Applies a variant of the approach given in
+  https://arxiv.org/pdf/2009.03106.pdf
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from. The
@@ -68,24 +70,22 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
     y_batch: A `tf.Tensor` representing a batch of output labels. The first
       axis must be the batch dimension. The number of examples should match
       the number of examples in `x_batch`.
-    layer_registry: A `dict` of layers that support "fast" gradient norm
-      computations. The key is the class of the layer and the value is a
-      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
-      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
-      squared norms of a layer's pre-activation tensor, and `vars` are relevant
-      trainable weights (see `layer_registry_factories.py` for examples).
+    layer_registry: A `LayerRegistry` instance containing functions that help
+      compute gradient norms quickly. See
+      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
+      more details.
 
   Returns:
     A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
     per-example loss function.
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
-  # First loop computes the norms of the layer inputs, caches these inputs,
-  # and computes the summed loss.
+  registry_generator_fn = get_registry_generator_fn(tape, layer_registry)
+  # First loop computes the model outputs, summed loss, and generator outputs.
   with tape:
-    model_outputs, pre_norm_list, var_list, layer_hash_list = (
-        gradient_clipping_utils.forward_norm_pass(
-            input_model, x_batch, tape, layer_registry
+    model_outputs, generator_outputs_list = (
+        gradient_clipping_utils.model_forward_pass(
+            input_model, x_batch, generator_fn=registry_generator_fn
         )
     )
   # Ignore the original loss function's reduction to get per-example loss.
@@ -94,20 +94,19 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   per_example_loss_fn = input_model.loss.from_config(loss_config)
   losses = per_example_loss_fn(y_batch, model_outputs)
   summed_loss = tf.reduce_sum(losses)
-  # Second loop computes the norm of the gradient of the loss with respect to
-  # the pre-activation tensors, and multiplies these norms with the results of
-  # the first loop.
-  full_norm_list = []
-  grads = tape.gradient(summed_loss, var_list)
-  for i in range(len(var_list)):
-    full_norm = combine_pre_and_post_sqr_norms(
-        pre_norm_list[i], grads[i], layer_hash_list[i]
-    )
-    full_norm_list.append(full_norm)
+  # Unwrap the generator outputs so that the next loop avoids duplicating
+  # backprop ops.
+  filtered_outputs = [t for t in generator_outputs_list if t is not None]
+  vars_list = [a for (a, b) in filtered_outputs]
+  sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
+  # Second loop evaluates the squared L2 norm functions and appends the
+  # results.
+  grads_list = tape.gradient(summed_loss, vars_list)
+  sqr_norm_list = []
+  for grads, f in zip(grads_list, sqr_norm_fns_list):
+    sqr_norm_list.append(f(grads))
   del tape
-  # Post-processing for compatibility with non-eager mode (very annoying).
-  full_norm_tsr = tf.stack(full_norm_list, axis=1)
-  return tf.sqrt(tf.reduce_sum(full_norm_tsr, axis=1))
+  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
+  return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
 
 def compute_clip_weights(l2_norm_clip, gradient_norms):
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
index 6da6a54..d325680 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
@@ -17,7 +17,7 @@ from absl.testing import parameterized
 import tensorflow as tf
 
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
-from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 
 
 # ==============================================================================
@@ -81,7 +81,7 @@ def get_computed_and_true_norms(
   )
   y_pred = model(x_input)
   y_batch = tf.ones_like(y_pred)
-  registry = layer_registry_factories.make_default_layer_registry()
+  registry = layer_registry.make_default_layer_registry()
   computed_norms = clip_grads.compute_gradient_norms(
       model, x_input, y_batch, layer_registry=registry
   )
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
index 96e95ed..6dd0d49 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
@@ -28,8 +28,19 @@ def has_internal_compute_graph(input_object):
   )
 
 
-def forward_norm_pass(input_model, x_batch, tape, layer_registry):
-  """Does a forward pass of a model and returns some useful intermediates.
+def _get_internal_layers(input_layer):
+  """Returns a list of layers that are nested within a given layer."""
+  internal_layers = []
+  if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
+    for layer in input_layer.layers:
+      internal_layers.extend(_get_internal_layers(layer))
+  else:
+    internal_layers.append(input_layer)
+  return internal_layers
+
+
+def model_forward_pass(input_model, inputs, generator_fn=None):
+  """Does a forward pass of a model and returns useful intermediates.
 
   NOTE: the graph traversal algorithm is an adaptation of the logic in the
   _run_internal_graph() method in the functional.Functional class. Hence,
@@ -37,44 +48,42 @@
   instance is an instance of the functional.Functional class.
 
   Args:
-    input_model: A Keras functional model to compute the quantities for.
-    x_batch: A collection of Tensors to be fed into the input layer of the
-      model.
-    tape: A tf.GradientTape() instance used to record certain operations.
-      Assumes that this function is being called inside of this tape.
-    layer_registry: A dictionary of layers that support "fast" gradient norm
-      computations. The key is the class of the layer and the value is a
-      function that returns a triple (norm_list, var_list, transform). For more
-      details, see `layer_registry_factories.py`.
+    input_model: A `tf.keras.Model` to compute the quantities for.
+    inputs: Arbitrary input to be fed into the input layer of the model. It is
+      expected that `input_model(inputs)` returns a valid output.
+    generator_fn: A function with signature `(tf.keras.layers.Layer, Any, Any)
+      -> (tf.Tensor, Any)`, where we require `generator_fn(layer_instance,
+      args, kwargs)[0] == layer_instance(*args, **kwargs)`. If `None`, then
+      `generator_fn(layer_instance, args, kwargs)[1] == None`.
 
   Returns:
-    Four objects (outputs, norm_list, var_list, layer_hash_list). The first are
-    the outputs that are generated as a result of a forward pass. The second is
-    either `None` or a collection of squared norms for each example that helps
-    in computing the gradient norm of an example (if such a "nice" computation
-    exists). The third is an ordered list of tf.Tensor() objects that are
-    intended to be evaluated with respect to the summed loss of the model. The
-    fourth is a list whose i-th element is the hash of the layer class of the
-    i-th element of var_list.
+    A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
+    `tf.Tensor` that is generated as a result of a forward pass.
+    `generator_outputs_list` is a `list` whose i-th entry is the output of
+    `generator_fn(lyr, args, kwargs)[1]`, where `lyr` is the i-th layer when
+    the compute graph of `input_model` is traversed in BFS order.
   """
   # TODO: Avoid or remove the references to protected methods of `input_model`.  # pylint: disable=g-bad-todo
+
+  # Default generator.
+  generator_outputs_list = []
+  if generator_fn is None:
+
+    def generator_fn(layer_instance, args, kwargs):
+      return layer_instance(*args, **kwargs), None
+
   # Prepare the inputs and BFS variables.
-  flattened_inputs = input_model._flatten_to_reference_inputs(x_batch)  # pylint: disable=protected-access
+  flattened_inputs = input_model._flatten_to_reference_inputs(inputs)  # pylint: disable=protected-access
   tensor_dict = {}
   tensor_usage_count = input_model._tensor_usage_count  # pylint: disable=protected-access
   for x, y in zip(input_model.inputs, flattened_inputs):
     y = input_model._conform_to_reference_input(y, ref_input=x)  # pylint: disable=protected-access
     x_id = str(id(x))
     tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
-
-  # Main computations.
   nodes_by_depth = input_model._nodes_by_depth  # pylint: disable=protected-access
   depth_keys = list(nodes_by_depth.keys())
   depth_keys.sort(reverse=True)
-  var_list = []
-  norm_list = []
-  layer_hash_list = []
-  node_outputs = None
+
+  # Perform BFS feedforward computations.
   for depth in depth_keys:
     for node in nodes_by_depth[depth]:
@@ -85,62 +94,28 @@
       args, kwargs = node.map_arguments(tensor_dict)
       if has_internal_compute_graph(node.layer):
         # If this node has an internal computational graph, we can recurse.
-        node_outputs, node_norm_list, node_var_list, node_layer_hash_list = (
-            forward_norm_pass(node.layer, args, tape, layer_registry)
+        node_layer_outputs, node_generator_outputs = model_forward_pass(
+            node.layer, args, generator_fn
         )
-        var_list += node_var_list
-        norm_list += node_norm_list
-        layer_hash_list += node_layer_hash_list
+        generator_outputs_list.extend(node_generator_outputs)
       else:
-        # Either pass through or record some metadata.
-        if not node.layer.trainable_variables:
-          node_outputs = node.layer(*args, **kwargs)
-        else:
-          lyr_hash = hash(node.layer.__class__)
-          if lyr_hash not in layer_registry:
-            raise NotImplementedError(
-                'Layer %s is not in the registry of known layers that can'
-                'be used for efficient gradient clipping.'
-                % node.layer.__class__.__name__
-            )
-          lyr = layer_registry[lyr_hash]
-          node_norms, node_vars, transform = lyr(node.layer, args)
-          tape.watch(node_vars)
-          node_outputs = transform(node_vars) if transform else node_vars
-          var_list.append(node_vars)
-          norm_list.append(node_norms)
-          layer_hash_list.append(lyr_hash)
-      # update the current dictionary of inputs for the next node.
-      for x_id, y in zip(node.flat_output_ids, tf.nest.flatten(node_outputs)):
+        # Otherwise, we parse the node directly.
+        node_layers = _get_internal_layers(node.layer)
+        for layer in node_layers:
+          node_layer_outputs, layer_generator_outputs = generator_fn(
+              layer, args, kwargs
+          )
+          generator_outputs_list.append(layer_generator_outputs)
+          args = (node_layer_outputs,)
+          kwargs = {}
+
+      # Update the current dictionary of inputs for the next node.
+      for x_id, y in zip(
+          node.flat_output_ids, tf.nest.flatten(node_layer_outputs)
+      ):
         tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
 
-  return node_outputs, norm_list, var_list, layer_hash_list
-
-
-def get_trainable_hidden_layers(input_model):
-  """Obtains the trainable hidden layers of a Keras model.
-
-  Args:
-    input_model: The Keras model to obtain the layers from.
-
-  Returns:
-    A list of Keras layers where the tensorflow.keras.layers.Layer
-    ancestor class MUST precede any existing tensorflow.keras.models.Model
-    ancestor class.
-  """
-  hidden_layers = []
-  for l in input_model.layers:
-    for c in l.__class__.__mro__:
-      if c == tf.keras.models.Model:
-        hidden_layers += get_trainable_hidden_layers(l)
-        break
-      elif c == tf.keras.layers.InputLayer:
-        break
-      elif c == tf.keras.layers.Layer:
-        if l.trainable_variables:
-          hidden_layers.append(l)
-          break
-  return hidden_layers
+  return node_layer_outputs, generator_outputs_list
 
 
 def all_trainable_layers_are_registered(input_model, layer_registry):
@@ -148,21 +123,19 @@
 
   Args:
     input_model: The Keras model from which to obtain the layers from.
-    layer_registry: A dictionary of layers that support "fast" gradient norm
-      computations. The key is the class of the layer and the value is a
-      function that returns a triple (output, sqr_grad_norms, vars), where
-      output is the pre-activator tensor, sqr_grad_norms is the square of the
-      norm of the layer's input, and vars is an ordered list of the trainable
-      weights.
+    layer_registry: A `LayerRegistry` instance containing functions that help
+      compute gradient norms quickly. See
+      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
+      more details.
 
   Returns:
     True if all the trainable layers in `input_model` are in `layer_registry`.
     False otherwise.
""" - hidden_layers = get_trainable_hidden_layers(input_model) - for l in hidden_layers: - if hash(l.__class__) not in layer_registry: - return False + for layer in input_model.layers: + for sublayer in _get_internal_layers(layer): + if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables: + return False return True @@ -212,3 +185,19 @@ def add_aggregate_noise( ) return tf.nest.map_structure(add_noise, clipped_grads) + + +def generate_model_outputs_using_core_keras_layers(input_model): + """Returns the model outputs generated by only core Keras layers.""" + cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects()) + cust_hash_set = set([hash(v) for v in cust_obj_dict.values()]) + + def generator_fn(layer_instance, args, kwargs): + if hash(layer_instance.__class__) in cust_hash_set: + # Using `.call()` does not register the layer in the compute graph of + # a forward pass. + return layer_instance.call(*args, **kwargs), None + else: + return layer_instance(*args, **kwargs), None + + return model_forward_pass(input_model, input_model.inputs, generator_fn)[0] diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py new file mode 100644 index 0000000..7e94dfa --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -0,0 +1,195 @@ +# Copyright 2022, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the layer registry class and useful factory functions. + +Defines "fast" gradient norm layer registry functions for use in the "fast" +gradient clipping algorithm. Specifically, each registry function takes +in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three +outputs: (a) a differentiable `tf.Tensor` `Z`, (b) either `None` or a function +that maps the object in (a) to the layer instance's output when using the +inputs in (ii), and (c) a function `F` that generates the per-example +squared gradient norms when it is fed an object representing the gradient of +the summed loss with respect to `Z` in (a). If (b) is `None`, then (a) is +expected to contain the layer outputs. + +When a layer registry function is defined, it is generally assumed that the +following relation holds: + + `|dL/dW|^2 == F(grad_Z)` + +where `gradient_Z` is the gradient of the summed loss with respect to `Z`. + +For example, if the layer instance is tf.keras.layers.Dense, Z contains the +pre-activation tensors, i.e., `z = X * w` for input `X`, and `g` is a tensor +whose i-th entry is the L2 norm of the i-th input vector, then + + `F(grad_Z) = g^2 * l2_row_norm(grad_Z)^2`, + +where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`. 
+Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
+"""
+
+import tensorflow as tf
+
+
+# ==============================================================================
+# Main class
+# ==============================================================================
+class LayerRegistry:
+  """Custom container for layer registry functions."""
+
+  def __init__(self):
+    """Basic initialization of various internal dictionaries."""
+    self._layer_class_dict = {}
+    self._registry = {}
+
+  def is_elem(self, layer_instance):
+    """Checks if a layer instance's class is in the registry."""
+    return hash(layer_instance.__class__) in self._registry
+
+  def lookup(self, layer_instance):
+    """Returns the layer registry function for a given layer instance."""
+    return self._registry[hash(layer_instance.__class__)]
+
+  def insert(self, layer_class, layer_registry_function):
+    """Inserts a layer registry function into the internal dictionaries."""
+    layer_key = hash(layer_class)
+    self._layer_class_dict[layer_key] = layer_class
+    self._registry[layer_key] = layer_registry_function
+
+
+# ==============================================================================
+# Supported Keras layers
+# ==============================================================================
+def dense_layer_computation(layer_instance, inputs):
+  """Registry function for `tf.keras.layers.Dense`.
+
+  The logic for this computation is based on the following paper:
+    https://arxiv.org/abs/1510.01799
+
+  For the sake of efficiency, we fuse the variables and square grad norms
+  for the kernel weights and bias vector together.
+
+  Args:
+    layer_instance: A `tf.keras.layers.Dense` instance.
+    inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
+      `layer_instance(inputs)` returns a valid output.
+
+  Returns:
+    A `tuple` `(base_vars, transform, sqr_norm_fn)`. `base_vars` is the
+    intermediate Tensor used in the chain-rule / "fast" clipping trick,
+    `transform` is a function that maps `base_vars` to the layer outputs, and
+    `sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
+    represents the output of the call `tape.gradient(summed_loss, base_vars)`
+    where `tape` is a `tf.GradientTape` instance that records the dense
+    layer computation and `summed_loss` is the sum of the per-example losses
+    of the underlying model. This function then returns the per-example squared
+    L2 gradient norms of the trainable variables in `layer_instance`. These
+    squared norms should be a 1D `tf.Tensor` of length `batch_size`.
+  """
+  orig_activation = layer_instance.activation
+  layer_instance.activation = None
+  base_vars = layer_instance(*inputs)
+  layer_instance.activation = orig_activation
+
+  def sqr_norm_fn(base_vars_grads):
+    sqr_inputs = tf.square(*inputs)
+    inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
+    input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes)
+    if layer_instance.use_bias:
+      # Adding a bias term is equivalent to a layer with no bias term and which
+      # adds an additional variable to the layer input that only takes a
+      # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
+      # of the squared values of the inputs.
+      input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype)
+    reduction_axes = tf.range(1, tf.rank(base_vars_grads))
+    base_vars_sqr_norms = tf.reduce_sum(
+        tf.square(base_vars_grads), axis=reduction_axes
+    )
+    return input_sqr_norms * base_vars_sqr_norms
+
+  return base_vars, layer_instance.activation, sqr_norm_fn
+
+
+def embedding_layer_computation(layer_instance, inputs):
+  """Registry function for `tf.keras.layers.Embedding`.
+
+  The logic of this computation is based on the `tf.keras.layers.Dense`
+  computation and the fact that an embedding layer is just a dense layer
+  with no activation function and an output vector of the form X*W for input
+  X, where the i-th row of W is the i-th embedding vector and the j-th row of
+  X is a one-hot vector representing the input of example j.
+
+  Args:
+    layer_instance: A `tf.keras.layers.Embedding` instance.
+    inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
+      `layer_instance(inputs)` returns a valid output.
+
+  Returns:
+    A `tuple` `(base_vars, None, sqr_norm_fn)`, where `base_vars` is the
+    intermediate Tensor used in the chain-rule / "fast" clipping trick, and
+    `sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
+    represents the output of the call `tape.gradient(summed_loss, base_vars)`
+    where `tape` is a `tf.GradientTape` instance that records the embedding
+    layer computation and `summed_loss` is the sum of the per-example losses
+    of the underlying model. This function then returns the per-example squared
+    L2 gradient norms of the trainable variables in `layer_instance`. These
+    squared norms should be a 1D `tf.Tensor` of length `batch_size`.
+  """
+  if hasattr(layer_instance, "sparse"):  # for backwards compatibility
+    if layer_instance.sparse:
+      raise NotImplementedError("Sparse output vectors are not supported.")
+  if len(inputs[0].shape) != 2:
+    raise NotImplementedError("Only 2D embedding inputs are supported.")
+  # The logic below is applied to properly handle repeated embedding indices.
+  # Specifically, `input_counts` will contain the total counts of each
+  # embedding index (see how they are used inside `sqr_norm_fn()` below).
+  # An example is as follows:
+  #
+  #   inputs =
+  #     [[0 0 0 1 2 2],
+  #      [0 2 2 2 1 1]]
+  #
+  #   counts =
+  #     [[3 1 2]
+  #      [1 2 3]]
+  #
+  #   input_counts =
+  #     [[3 3 3 1 2 2],
+  #      [1 3 3 3 2 2]]
+  #
+  base_vars = layer_instance(*inputs)
+
+  def sqr_norm_fn(base_vars_grads):
+    indices = tf.cast(*inputs, tf.int32)
+    if isinstance(indices, tf.SparseTensor):
+      indices = tf.sparse.to_dense(indices)
+    counts = tf.math.bincount(indices, axis=-1)
+    input_counts = tf.expand_dims(
+        tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
+        axis=-1,
+    )
+    scaled_grads = input_counts * tf.square(base_vars_grads)
+    reduction_axes = tf.range(1, tf.rank(scaled_grads))
+    return tf.reduce_sum(scaled_grads, axis=reduction_axes)
+
+  return base_vars, None, sqr_norm_fn
+
+
+# ==============================================================================
+# Main factory methods
+# ==============================================================================
+def make_default_layer_registry():
+  registry = LayerRegistry()
+  registry.insert(tf.keras.layers.Dense, dense_layer_computation)
+  registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
+  return registry
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry_factories.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry_factories.py
deleted file mode 100644
index d4f7f4c..0000000
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry_factories.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2022, The TensorFlow Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Generates a default layer registry.
-
-Defines "fast" gradient norm layer registry functions for use in the "fast"
-gradient clipping algorithm. Specifically, each registry function takes in
-two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
-outputs: (a) a `tf.Tensor` `G` of gradient norms, (b) a differentiable
-`tf.Tensor` `Z`, and (c) either `None` or a function that maps the object in
-(b) to the layer instance's output when using the inputs in (ii). If (c) is
-`None`, then (b) contains the layer outputs.
-
-When a layer registry function is defined, it is generally assumed that the
-following relation holds for each pair `(g, z)` in `zip(G, Z)`:
-
-  `|dL/dw|^2 == |dL/dz|^2 * g^2`
-
-where `L` is any per-example loss function and `w` are the trainable variables
-corresponding to `(g, z)`.
-
-For example, this relation holds if the layer instance is
-tf.keras.layers.Dense, Z contains the pre-activation tensors, i.e.,
-`z = X * w` for input `X`, and `g` is the norm of the input corresponding to
-the given per-example loss (see the formulae in
-https://arxiv.org/abs/1510.01799 for more details).
-
-The registry functions are registered in a `dict` (registry) whose key is the
-hash of the layer class and whose value is the registry function.
-""" - -import tensorflow as tf - - -# ============================================================================== -# Supported Keras layers -# ============================================================================== -def dense_layer_computation(layer_instance, inputs): - """Registry function for `tf.keras.layers.Dense`. - - The logic for this computation is based on the following paper: - https://arxiv.org/abs/1510.01799 - - For the sake of efficiency, we fuse the variables and square grad norms - for the kernel weights and bias vector together. - - Args: - layer_instance: A `tf.keras.layers.Dense` instance. - inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., - `layer_instance(inputs)` returns a valid output. - - Returns: - A `tuple` `(sqr_grad_norms, base_vars, transform)`, where `norms` is a 1D - `tf.Tensor` of the squared l2-norms of the input tensors, `base_vars` is the - intermediate Tensor used in the chain-rule / "fast" clipping trick, and - `transform` is a function that maps `base_vars` to the layer outputs. - """ - orig_activation = layer_instance.activation - layer_instance.activation = None - base_vars = layer_instance(*inputs) - sqr_inputs = tf.square(*inputs) - inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs)) - sqr_grad_norms = tf.reduce_sum(tf.square(*inputs), axis=inputs_reduction_axes) - if layer_instance.use_bias: - # Adding a bias term is equivalent to a layer with no bias term and which - # adds an additional variable to the layer input that only takes a constant - # value of 1.0. This is thus equivalent to adding 1.0 to the sum of the - # squared values of the inputs. - sqr_grad_norms += tf.cast(1.0, dtype=sqr_grad_norms.dtype) - layer_instance.activation = orig_activation - return sqr_grad_norms, base_vars, layer_instance.activation - - -def embedding_layer_computation(layer_instance, inputs): - """Registry function for `tf.keras.layers.Embedding`. - - The logic of this computation is based on the `tf.keras.layers.Dense` - computation and the fact that an embedding layer is just a dense layer - with no activation function and an output vector of the form X*W for input - X, where the i-th row of W is the i-th embedding vector and the j-th row of - X is a one-hot vector representing the input of example j. - - Args: - layer_instance: A `tf.keras.layers.Embedding` instance. - inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., - `layer_instance(inputs)` returns a valid output. - - Returns: - A `tuple` `(sqr_grad_norms, base_vars, None)`, where `sqr_grad_norms` is - a `tf.Tensor` that is related to the squared l2-norms of the input tensors - and `base_vars` is the intermediate Tensor used in the chain-rule / "fast" - clipping trick. - """ - if hasattr(layer_instance, "sparse"): # for backwards compatibility - if layer_instance.sparse: - raise NotImplementedError("Sparse output vectors are not supported.") - if tf.rank(*inputs) != 2: - raise NotImplementedError("Only 2D embedding inputs are supported.") - # The logic below is applied to properly handle repeated embedding indices. - # Specifically, sqr_grad_norms will contain the total counts of each embedding - # index (see how it is processed in the combine_pre_and_post_sqr_norms() - # function in clip_grads.py). 
-  # An example is as follows:
-  #
-  #   inputs =
-  #     [[0 0 0 1 2 2],
-  #      [0 2 2 2 1 1]]
-  #
-  #   counts =
-  #     [[3 1 2]
-  #      [1 2 3]]
-  #
-  #   sqr_grad_norms =
-  #     [[3 3 3 1 2 2],
-  #      [1 3 3 3 2 2]]
-  #
-  base_vars = layer_instance(*inputs)
-  indices = tf.cast(*inputs, tf.int32)
-  if isinstance(indices, tf.SparseTensor):
-    indices = tf.sparse.to_dense(indices)
-  counts = tf.math.bincount(indices, axis=-1)
-  sqr_grad_norms = tf.cast(
-      tf.gather(counts, indices, batch_dims=1), base_vars.dtype
-  )
-  return sqr_grad_norms, base_vars, None
-
-
-# ==============================================================================
-# Main factory methods
-# ==============================================================================
-def make_default_layer_registry():
-  registry = {}
-  registry[hash(tf.keras.layers.Dense)] = dense_layer_computation
-  registry[hash(tf.keras.layers.Embedding)] = embedding_layer_computation
-  return registry
diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD
index 407537a..5a20c38 100644
--- a/tensorflow_privacy/privacy/keras_models/BUILD
+++ b/tensorflow_privacy/privacy/keras_models/BUILD
@@ -18,7 +18,7 @@ py_library(
     deps = [
         "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
-        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
     ],
 )
 
@@ -28,7 +28,7 @@ py_test(
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
-        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
        "//tensorflow_privacy/privacy/keras_models:dp_keras_model",
    ],
 )
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
index c2e2421..5f445d2 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -96,13 +96,10 @@ def make_dp_model_class(cls):
       noise_multiplier: Ratio of the standard deviation to the clipping norm.
       num_microbatches: Number of microbatches.
       use_xla: If `True`, compiles train_step to XLA.
-      layer_registry: A `dict` of layers that support "fast" gradient norm
-        computations. The key is the class of the layer and the value is a
-        function that returns a `tuple` `(output, sqr_grad_norms, vars)`,
-        where `output` is the pre-activator tensor, `sqr_grad_norms` is
-        related to the squared norms of a layer's pre-activation tensor, and
-        `vars` are relevant trainable weights (see
-        `layer_registry_factories.py` for examples).
+      layer_registry: A `LayerRegistry` instance containing functions that
+        help compute gradient norms quickly. See
+        `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
+        more details.
       *args: These will be passed on to the base class `__init__` method.
       **kwargs: These will be passed on to the base class `__init__` method.
""" diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py index 172a853..4bb3c4f 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py @@ -15,7 +15,7 @@ from absl.testing import parameterized import numpy as np import tensorflow as tf -from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.keras_models import dp_keras_model @@ -32,7 +32,10 @@ def get_layer_registries(): # Outputs a list of testable layer registries. # The empty registry {} tests the behavior of the standard approach, # while the other one tests the fast gradient clipping algorithm. - return [{}, layer_registry_factories.make_default_layer_registry()] + return [ + layer_registry.LayerRegistry(), + layer_registry.make_default_layer_registry(), + ] class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):