Add a kwargs argument to the registry API + small changes to docstrings.

This is a forward-looking change that is needed to support more complicated
layers, such as `tf.keras.layers.MultiHeadAttention`, which can take `kwargs`
as part of their `.call()` method and can generate arbitrary outputs.
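
For illustration, a registry function under the new API receives the layer's positional and keyword call arguments separately. The sketch below is not part of this commit; it only mirrors the signature introduced here (compare `double_dense_layer_computation` in the test diff), and the function name and square-norm placeholder are hypothetical.

# Minimal sketch (assumption, not from this commit) of a registry function
# written against the new signature.
from typing import Any, Dict, Optional, Text, Tuple

import tensorflow as tf


def my_layer_computation(  # hypothetical name
    layer_instance: tf.keras.layers.Layer,
    input_args: Tuple[Any, ...],
    input_kwargs: Dict[Text, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[int] = None,
):
  """Toy registry function returning `(base_vars, outputs, sqr_norm_fn)`."""
  del num_microbatches  # Microbatching is ignored in this sketch.
  base_vars = layer_instance(*input_args, **input_kwargs)
  tape.watch(base_vars)
  outputs = base_vars  # No post-processing in this toy example.

  def sqr_norm_fn(base_vars_grads):
    # Placeholder: a real registry function converts these gradients into
    # per-example squared L2 norms of the layer's trainable-variable gradients.
    return tf.reduce_sum(
        tf.square(base_vars_grads),
        axis=tf.range(1, tf.rank(base_vars_grads)),
    )

  return base_vars, outputs, sqr_norm_fn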

PiperOrigin-RevId: 514775503
A. Unique TensorFlower 2023-03-07 10:34:20 -08:00
parent 21ee1a607a
commit 4e1fc252e4
4 changed files with 51 additions and 44 deletions

View file

@@ -52,7 +52,7 @@ def get_registry_generator_fn(
           )
         registry_fn = layer_registry.lookup(layer_instance)
         (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
-            layer_instance, args, tape, num_microbatches
+            layer_instance, args, kwargs, tape, num_microbatches
         )
         return layer_outputs, (layer_vars, layer_sqr_norm_fn)
       else:

View file

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import itertools
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
 
 from absl.testing import parameterized
 import tensorflow as tf
@@ -50,16 +50,17 @@ class DoubleDense(tf.keras.layers.Layer):
 
 def double_dense_layer_computation(
     layer_instance: tf.keras.layers.Layer,
-    inputs: Any,
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[int],
-):
+) -> layer_registry.RegistryFunctionOutput:
   """Layer registry function for the custom `DoubleDense` layer class."""
   vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
-      layer_instance.dense1, inputs, tape, num_microbatches
+      layer_instance.dense1, input_args, input_kwargs, tape, num_microbatches
   )
   vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation(
-      layer_instance.dense2, (outputs,), tape, num_microbatches
+      layer_instance.dense2, (outputs,), {}, tape, num_microbatches
   )
 
   def sqr_norm_fn(base_vars):
@@ -75,7 +76,7 @@ def compute_true_gradient_norms(
     x_batch: tf.Tensor,
     y_batch: tf.Tensor,
     num_microbatches: Optional[int],
-):
+) -> layer_registry.OutputTensor:
   """Computes the real gradient norms for an input `(model, x, y)`."""
   loss_config = input_model.loss.get_config()
   loss_config['reduction'] = tf.keras.losses.Reduction.NONE
@@ -113,7 +114,7 @@ def get_computed_and_true_norms(
     x_input: tf.Tensor,
     rng_seed: int = 777,
     registry: layer_registry.LayerRegistry = None,
-):
+) -> Tuple[tf.Tensor, tf.Tensor]:
   """Obtains the true and computed gradient norms for a model and batch input.
 
   Helpful testing wrapper function used to avoid code duplication.

View file

@@ -120,7 +120,11 @@ def model_forward_pass(
           layer, args, kwargs
       )
       generator_outputs_list.append(layer_generator_outputs)
-      args = (node_layer_outputs,)
+      args = (
+          node_layer_outputs
+          if isinstance(node_layer_outputs, tuple)
+          else (node_layer_outputs,)
+      )
       kwargs = {}
 
       # Update the current dictionary of inputs for the next node.
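
The `model_forward_pass` change above stops multi-output layers from being re-wrapped: a tuple of outputs is forwarded as-is, while a lone tensor is still wrapped in a 1-tuple before being splatted into the next layer call. A small self-contained check of that behaviour (the tensors are made up for illustration):

import tensorflow as tf

# Hypothetical two-output layer result, e.g. MultiHeadAttention called with
# return_attention_scores=True.
node_layer_outputs = (tf.zeros([4, 8, 16]), tf.zeros([4, 2, 8, 8]))
args = (
    node_layer_outputs
    if isinstance(node_layer_outputs, tuple)
    else (node_layer_outputs,)
)
assert len(args) == 2  # both tensors become separate positional arguments

# Single-output layer result.
node_layer_outputs = tf.zeros([4, 8])
args = (
    node_layer_outputs
    if isinstance(node_layer_outputs, tuple)
    else (node_layer_outputs,)
)
assert len(args) == 1  # a lone tensor is still wrapped in a 1-tuple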

View file

@@ -56,17 +56,21 @@ import tensorflow as tf
 # ==============================================================================
 # Type aliases
 # ==============================================================================
-SquareNormFunction = Callable[[Any], tf.Tensor]
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
 
-RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]
+OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]]
+
+BatchSize = Union[int, tf.Tensor]
+
+SquareNormFunction = Callable[[OutputTensor], tf.Tensor]
+
+RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction]
 
 RegistryFunction = Callable[
-    [Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput
+    [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
+    RegistryFunctionOutput,
 ]
 
-InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
-
-BatchSize = Union[int, tf.Tensor]
-
 # ==============================================================================
 # Main class
@@ -134,7 +138,8 @@ def add_microbatch_axis(
 # ==============================================================================
 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
-    inputs: Tuple[InputTensor],
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> RegistryFunctionOutput:
@@ -148,9 +153,12 @@ def dense_layer_computation(
 
   Args:
     layer_instance: A `tf.keras.layers.Dense` instance.
-    inputs: A tuple containing a single `InputTensor` which can be passed into
-      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
-      output.
+    input_args: A `tuple` containing the first part of the `layer_instance`
+      input. Specifically, `layer_instance(*input_args, **input_kwargs)` should
+      return a valid output.
+    input_kwargs: A `dict` containing the second part of the `layer_instance`
+      input. Specifically, `layer_instance(*input_args, **input_kwargs)` should
+      return a valid output.
     tape: A `tf.GradientTape` instance that will be used to watch the output
       `base_vars`.
     num_microbatches: An optional numeric value or scalar `tf.Tensor` for
@@ -169,11 +177,14 @@ def dense_layer_computation(
     trainable variables in `layer_instance`. These squared norms should be a 1D
     `tf.Tensor` of length `batch_size`.
   """
-  if len(inputs) != 1:
+  if input_kwargs:
+    raise ValueError("Dense layer calls should not receive kwargs.")
+  del input_kwargs  # Unused in dense layer calls.
+  if len(input_args) != 1:
     raise ValueError("Only layer inputs of length 1 are permitted.")
   orig_activation = layer_instance.activation
   layer_instance.activation = None
-  base_vars = layer_instance(*inputs)
+  base_vars = layer_instance(*input_args)
   tape.watch(base_vars)
   layer_instance.activation = orig_activation
   outputs = orig_activation(base_vars) if orig_activation else base_vars
@@ -188,7 +199,7 @@ def dense_layer_computation(
       # Special handling for better efficiency
       return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
 
-  inputs_gram = _compute_gramian(*inputs)
+  inputs_gram = _compute_gramian(*input_args)
   base_vars_grads_gram = _compute_gramian(base_vars_grads)
   if layer_instance.use_bias:
     # Adding a bias term is equivalent to a layer with no bias term and which
@@ -206,7 +217,8 @@
 
 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
-    inputs: Tuple[InputTensor],
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> RegistryFunctionOutput:
@@ -220,33 +232,23 @@ def embedding_layer_computation(
 
   Args:
     layer_instance: A `tf.keras.layers.Embedding` instance.
-    inputs: A tuple containing a single `InputTensor` which can be passed into
-      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
-      output.
-    tape: A `tf.GradientTape` instance that will be used to watch the output
-      `base_vars`.
-    num_microbatches: An optional numeric value or scalar `tf.Tensor` for
-      indicating whether and how the losses are grouped into microbatches. If
-      not None, num_microbatches must divide the batch size.
+    input_args: See `dense_layer_computation()`.
+    input_kwargs: See `dense_layer_computation()`.
+    tape: See `dense_layer_computation()`.
+    num_microbatches: See `dense_layer_computation()`.
 
   Returns:
-    A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
-    intermediate Tensor used in the chain-rule / "fast" clipping trick,
-    `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
-    a function that takes one input, a `tf.Tensor` that represents the output
-    of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
-    `tf.GradientTape` instance that records the dense layer computation and
-    `summed_loss` is the sum of the per-example losses of the underlying model.
-    This function then returns the per-example squared L2 gradient norms of the
-    trainable variables in `layer_instance`. These squared norms should be a 1D
-    `tf.Tensor` of length `batch_size`.
+    See `dense_layer_computation()`.
   """
-  if len(inputs) != 1:
+  if input_kwargs:
+    raise ValueError("Embedding layer calls should not receive kwargs.")
+  del input_kwargs  # Unused in embedding layer calls.
+  if len(input_args) != 1:
     raise ValueError("Only layer inputs of length 1 are permitted.")
   if hasattr(layer_instance, "sparse"):  # for backwards compatibility
     if layer_instance.sparse:
       raise NotImplementedError("Sparse output tensors are not supported.")
-  if isinstance(inputs[0], tf.SparseTensor):
+  if isinstance(input_args[0], tf.SparseTensor):
     raise NotImplementedError("Sparse input tensors are not supported.")
 
   # Disable experimental features.
@@ -256,7 +258,7 @@ def embedding_layer_computation(
         "The experimental embedding feature"
         "'_use_one_hot_matmul' is not supported."
     )
-  input_ids = tf.cast(*inputs, tf.int32)
+  input_ids = tf.cast(*input_args, tf.int32)
   base_vars = layer_instance.trainable_variables[0]
   tape.watch(base_vars)
   outputs = tf.nn.embedding_lookup(base_vars, input_ids)
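
For a concrete sense of the new calling convention, the snippet below invokes `dense_layer_computation` directly on a small `Dense` layer. It is a usage sketch, not test code from this commit; the import path is an assumption and should be adjusted to wherever `layer_registry` actually lives.

import tensorflow as tf

# Assumed import path, for illustration only.
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

layer = tf.keras.layers.Dense(3)
x = tf.random.normal([4, 5])  # batch of 4 examples

with tf.GradientTape() as tape:
  # New signature: positional args as a tuple, keyword args as a dict (empty
  # here), then the tape and the optional number of microbatches.
  base_vars, outputs, sqr_norm_fn = layer_registry.dense_layer_computation(
      layer, (x,), {}, tape, num_microbatches=None
  )
  summed_output = tf.reduce_sum(outputs)  # stand-in for a summed loss

base_vars_grads = tape.gradient(summed_output, base_vars)
per_example_sqr_norms = sqr_norm_fn(base_vars_grads)  # 1-D tensor of length 4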