Add a kwargs argument to the registry API + small changes to docstrings.

This is a forward-looking change needed to support more complex layers, such as
`tf.keras.layers.MultiHeadAttention`, whose `.call()` methods can take `kwargs`
and can generate arbitrary outputs.

PiperOrigin-RevId: 514775503
A. Unique TensorFlower 2023-03-07 10:34:20 -08:00
parent 21ee1a607a
commit 4e1fc252e4
4 changed files with 51 additions and 44 deletions
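
For orientation, here is a minimal sketch of the registry-function contract after this change. Everything concrete in it is hypothetical (the placeholder body does not compute real per-example norms); only the `(layer_instance, input_args, input_kwargs, tape, num_microbatches)` argument order and the `(base_vars, outputs, sqr_norm_fn)` return triple follow the diff below.

from typing import Any, Dict, Optional, Text, Tuple

import tensorflow as tf


def my_custom_layer_computation(
    layer_instance: tf.keras.layers.Layer,
    input_args: Tuple[Any, ...],
    input_kwargs: Dict[Text, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
):
  """Hypothetical registry function using the new kwargs-aware signature."""
  del num_microbatches  # Sketch only; microbatching is not handled here.
  # Keyword arguments (e.g., attention masks or `training` flags) are now
  # forwarded to the layer call instead of being dropped.
  base_vars = layer_instance(*input_args, **input_kwargs)
  tape.watch(base_vars)
  outputs = base_vars

  def sqr_norm_fn(base_vars_grads):
    # Placeholder: a real registry function returns the per-example squared
    # L2 gradient norms of the layer's trainable variables.
    return tf.reduce_sum(
        tf.square(base_vars_grads), axis=tf.range(1, tf.rank(base_vars_grads))
    )

  return base_vars, outputs, sqr_norm_fn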

View file

@@ -52,7 +52,7 @@ def get_registry_generator_fn(
)
registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, tape, num_microbatches
layer_instance, args, kwargs, tape, num_microbatches
)
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
else:

View file

@@ -13,7 +13,7 @@
# limitations under the License.
import itertools
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
from absl.testing import parameterized
import tensorflow as tf
@@ -50,16 +50,17 @@ class DoubleDense(tf.keras.layers.Layer):
def double_dense_layer_computation(
layer_instance: tf.keras.layers.Layer,
inputs: Any,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[int],
):
) -> layer_registry.RegistryFunctionOutput:
"""Layer registry function for the custom `DoubleDense` layer class."""
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
layer_instance.dense1, inputs, tape, num_microbatches
layer_instance.dense1, input_args, input_kwargs, tape, num_microbatches
)
vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation(
layer_instance.dense2, (outputs,), tape, num_microbatches
layer_instance.dense2, (outputs,), {}, tape, num_microbatches
)
def sqr_norm_fn(base_vars):
@@ -75,7 +76,7 @@ def compute_true_gradient_norms(
x_batch: tf.Tensor,
y_batch: tf.Tensor,
num_microbatches: Optional[int],
):
) -> layer_registry.OutputTensor:
"""Computes the real gradient norms for an input `(model, x, y)`."""
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
@@ -113,7 +114,7 @@ def get_computed_and_true_norms(
x_input: tf.Tensor,
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
):
) -> Tuple[tf.Tensor, tf.Tensor]:
"""Obtains the true and computed gradient norms for a model and batch input.
Helpful testing wrapper function used to avoid code duplication.

View file

@@ -120,7 +120,11 @@ def model_forward_pass(
layer, args, kwargs
)
generator_outputs_list.append(layer_generator_outputs)
args = (node_layer_outputs,)
args = (
node_layer_outputs
if isinstance(node_layer_outputs, tuple)
else (node_layer_outputs,)
)
kwargs = {}
# Update the current dictionary of inputs for the next node.

View file

@@ -56,17 +56,21 @@ import tensorflow as tf
# ==============================================================================
# Type aliases
# ==============================================================================
SquareNormFunction = Callable[[Any], tf.Tensor]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]
OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]]
BatchSize = Union[int, tf.Tensor]
SquareNormFunction = Callable[[OutputTensor], tf.Tensor]
RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction]
RegistryFunction = Callable[
[Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput
[Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
RegistryFunctionOutput,
]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
BatchSize = Union[int, tf.Tensor]
# ==============================================================================
# Main class
@@ -134,7 +138,8 @@ def add_microbatch_axis(
# ==============================================================================
def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
inputs: Tuple[InputTensor],
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput:
@@ -148,9 +153,12 @@ def dense_layer_computation(
Args:
layer_instance: A `tf.keras.layers.Dense` instance.
inputs: A tuple containing a single `InputTensor` which can be passed into
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
output.
input_args: A `tuple` containing the first part of the `layer_instance` input.
Specifically, `layer_instance(*input_args, **input_kwargs)` should return
a valid output.
input_kwargs: A `dict` containing the second part of the `layer_instance`
input. Specifically, `layer_instance(*input_args, **input_kwargs)` should
return a valid output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
@@ -169,11 +177,14 @@ def dense_layer_computation(
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
"""
if len(inputs) != 1:
if input_kwargs:
raise ValueError("Dense layer calls should not receive kwargs.")
del input_kwargs # Unused in dense layer calls.
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
orig_activation = layer_instance.activation
layer_instance.activation = None
base_vars = layer_instance(*inputs)
base_vars = layer_instance(*input_args)
tape.watch(base_vars)
layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars
@@ -188,7 +199,7 @@ def dense_layer_computation(
# Special handling for better efficiency
return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
inputs_gram = _compute_gramian(*inputs)
inputs_gram = _compute_gramian(*input_args)
base_vars_grads_gram = _compute_gramian(base_vars_grads)
if layer_instance.use_bias:
# Adding a bias term is equivalent to a layer with no bias term and which
@@ -206,7 +217,8 @@ def embedding_layer_computation(
def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
inputs: Tuple[InputTensor],
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput:
@@ -220,33 +232,23 @@ def embedding_layer_computation(
Args:
layer_instance: A `tf.keras.layers.Embedding` instance.
inputs: A tuple containing a single `InputTensor` which can be passed into
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
indicating whether and how the losses are grouped into microbatches. If
not None, num_microbatches must divide the batch size.
input_args: See `dense_layer_computation()`.
input_kwargs: See `dense_layer_computation()`.
tape: See `dense_layer_computation()`.
num_microbatches: See `dense_layer_computation()`.
Returns:
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick,
`outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
a function that takes one input, a `tf.Tensor` that represents the output
of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
`tf.GradientTape` instance that records the dense layer computation and
`summed_loss` is the sum of the per-example losses of the underlying model.
This function then returns the per-example squared L2 gradient norms of the
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
See `dense_layer_computation()`.
"""
if len(inputs) != 1:
if input_kwargs:
raise ValueError("Embedding layer calls should not receive kwargs.")
del input_kwargs # Unused in embedding layer calls.
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
if hasattr(layer_instance, "sparse"): # for backwards compatibility
if layer_instance.sparse:
raise NotImplementedError("Sparse output tensors are not supported.")
if isinstance(inputs[0], tf.SparseTensor):
if isinstance(input_args[0], tf.SparseTensor):
raise NotImplementedError("Sparse input tensors are not supported.")
# Disable experimental features.
@@ -256,7 +258,7 @@ def embedding_layer_computation(
"The experimental embedding feature"
"'_use_one_hot_matmul' is not supported."
)
input_ids = tf.cast(*inputs, tf.int32)
input_ids = tf.cast(*input_args, tf.int32)
base_vars = layer_instance.trainable_variables[0]
tape.watch(base_vars)
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
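
As a usage sketch (not part of this commit), registering such a function for a hypothetical custom layer class would look roughly like the following, assuming the `LayerRegistry.insert(layer_class, registry_fn)` and `make_default_layer_registry()` helpers in `layer_registry.py`:

# `MyCustomLayer` and `my_custom_layer_computation` are the hypothetical
# placeholders sketched near the top of this page.
registry = layer_registry.make_default_layer_registry()
registry.insert(MyCustomLayer, my_custom_layer_computation)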