Add a kwargs argument to the registry API + small changes to docstrings.

This is a forward-looking change needed to support more complicated layers,
such as `tf.keras.layers.MultiHeadAttention`, whose `.call()` method can
accept `kwargs` and can generate arbitrary outputs.

PiperOrigin-RevId: 514775503
This commit is contained in:
parent 21ee1a607a
commit 4e1fc252e4

4 changed files with 51 additions and 44 deletions
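In short: a registry function now receives the layer's call inputs split into
a positional tuple and a keyword dict. A minimal sketch of the new convention
(the body below is illustrative, not the library's implementation):

    import tensorflow as tf

    def example_layer_computation(layer_instance, input_args, input_kwargs,
                                  tape, num_microbatches=None):
      # New convention: both halves of the original `layer.call()` input are
      # forwarded, so layers like MultiHeadAttention can receive kwargs.
      base_vars = layer_instance(*input_args, **input_kwargs)
      tape.watch(base_vars)

      def sqr_norm_fn(base_vars_grads):
        # Per-example squared L2 norms of the gradients w.r.t. `base_vars`.
        return tf.reduce_sum(
            tf.square(base_vars_grads),
            axis=tf.range(1, tf.rank(base_vars_grads)),
        )

      return layer_instance.trainable_variables, base_vars, sqr_norm_fn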
@@ -52,7 +52,7 @@ def get_registry_generator_fn(
         )
         registry_fn = layer_registry.lookup(layer_instance)
         (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
-            layer_instance, args, tape, num_microbatches
+            layer_instance, args, kwargs, tape, num_microbatches
         )
         return layer_outputs, (layer_vars, layer_sqr_norm_fn)
       else:
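This hunk threads the layer call's `kwargs` through to the registry function.
A hedged usage sketch (the import path and `make_default_layer_registry` are
assumptions about the library; only `lookup` appears in this hunk):

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

    registry = layer_registry.make_default_layer_registry()
    dense = tf.keras.layers.Dense(2)
    args, kwargs = (tf.ones([3, 5]),), {}  # dense layers take no call kwargs

    with tf.GradientTape() as tape:
      registry_fn = registry.lookup(dense)
      layer_vars, layer_outputs, layer_sqr_norm_fn = registry_fn(
          dense, args, kwargs, tape, None  # None = no microbatching
      )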
@@ -13,7 +13,7 @@
 # limitations under the License.

 import itertools
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union

 from absl.testing import parameterized
 import tensorflow as tf
@@ -50,16 +50,17 @@ class DoubleDense(tf.keras.layers.Layer):

 def double_dense_layer_computation(
     layer_instance: tf.keras.layers.Layer,
-    inputs: Any,
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[int],
-):
+) -> layer_registry.RegistryFunctionOutput:
   """Layer registry function for the custom `DoubleDense` layer class."""
   vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
-      layer_instance.dense1, inputs, tape, num_microbatches
+      layer_instance.dense1, input_args, input_kwargs, tape, num_microbatches
   )
   vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation(
-      layer_instance.dense2, (outputs,), tape, num_microbatches
+      layer_instance.dense2, (outputs,), {}, tape, num_microbatches
   )

   def sqr_norm_fn(base_vars):
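For the test to exercise this function, `DoubleDense` must be mapped to it in
a registry. A sketch of the wiring (`make_default_layer_registry` and `insert`
are assumptions about the `LayerRegistry` API, which this diff only shows in
part):

    registry = layer_registry.make_default_layer_registry()
    registry.insert(DoubleDense, double_dense_layer_computation)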
@@ -75,7 +76,7 @@ def compute_true_gradient_norms(
     x_batch: tf.Tensor,
     y_batch: tf.Tensor,
     num_microbatches: Optional[int],
-):
+) -> layer_registry.OutputTensor:
   """Computes the real gradient norms for an input `(model, x, y)`."""
   loss_config = input_model.loss.get_config()
   loss_config['reduction'] = tf.keras.losses.Reduction.NONE
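The `Reduction.NONE` override matters because the helper needs one loss value
per example to take per-example gradients; the default reduction collapses the
loss to a scalar. A self-contained illustration:

    import tensorflow as tf

    loss_fn = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE
    )
    per_example_losses = loss_fn(tf.ones([4, 2]), tf.zeros([4, 2]))
    print(per_example_losses.shape)  # (4,): one loss per example, not a scalar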
@@ -113,7 +114,7 @@ def get_computed_and_true_norms(
     x_input: tf.Tensor,
     rng_seed: int = 777,
     registry: layer_registry.LayerRegistry = None,
-):
+) -> Tuple[tf.Tensor, tf.Tensor]:
   """Obtains the true and computed gradient norms for a model and batch input.

   Helpful testing wrapper function used to avoid code duplication.
@@ -120,7 +120,11 @@ def model_forward_pass(
           layer, args, kwargs
       )
       generator_outputs_list.append(layer_generator_outputs)
-      args = (node_layer_outputs,)
+      args = (
+          node_layer_outputs
+          if isinstance(node_layer_outputs, tuple)
+          else (node_layer_outputs,)
+      )
+      kwargs = {}

   # Update the current dictionary of inputs for the next node.
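The new branching keeps multi-output layers working: outputs that are already
a tuple are forwarded as-is, while a single tensor is wrapped so the next node
still receives a tuple of positional args; resetting `kwargs = {}` reflects
that only the first call receives the user-supplied keyword arguments. In
isolation:

    import tensorflow as tf

    def as_args(node_layer_outputs):
      # Mirrors the hunk above: wrap single outputs, pass tuples through.
      return (
          node_layer_outputs
          if isinstance(node_layer_outputs, tuple)
          else (node_layer_outputs,)
      )

    print(as_args(tf.ones([2])))                   # 1-tuple wrapping the tensor
    print(as_args((tf.ones([2]), tf.zeros([2]))))  # unchanged 2-tuple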
@@ -56,17 +56,21 @@ import tensorflow as tf
 # ==============================================================================
 # Type aliases
 # ==============================================================================
-SquareNormFunction = Callable[[Any], tf.Tensor]
-InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
+OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]]

-RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]
+SquareNormFunction = Callable[[OutputTensor], tf.Tensor]

-BatchSize = Union[int, tf.Tensor]
+RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction]

 RegistryFunction = Callable[
-    [Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput
+    [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
+    RegistryFunctionOutput,
 ]

+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
+BatchSize = Union[int, tf.Tensor]


 # ==============================================================================
 # Main class
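For reference, a toy function that type-checks against the widened
`RegistryFunction` alias (not library code; the aliases are copied from the
new version above):

    from typing import Any, Callable, Dict, Iterable, Text, Tuple, Union
    import tensorflow as tf

    OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]]
    SquareNormFunction = Callable[[OutputTensor], tf.Tensor]
    RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction]

    def trivial_computation(
        layer_instance: Any,
        input_args: Tuple[Any, ...],
        input_kwargs: Dict[Text, Any],
        tape: tf.GradientTape,
    ) -> RegistryFunctionOutput:
      # Matches the new alias: args tuple first, then the kwargs dict.
      outputs = layer_instance(*input_args, **input_kwargs)
      tape.watch(outputs)

      def sqr_norm_fn(grads):
        return tf.reduce_sum(tf.square(grads), axis=tf.range(1, tf.rank(grads)))

      return layer_instance.trainable_variables, outputs, sqr_norm_fn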
@@ -134,7 +138,8 @@ def add_microbatch_axis(
 # ==============================================================================
 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
-    inputs: Tuple[InputTensor],
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> RegistryFunctionOutput:
@@ -148,9 +153,12 @@ def dense_layer_computation(

   Args:
     layer_instance: A `tf.keras.layers.Dense` instance.
-    inputs: A tuple containing a single `InputTensor` which can be passed into
-      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
-      output.
+    input_args: A `tuple` containing the first part of `layer_instance` input.
+      Specifically, `layer_instance(*input_args, **input_kwargs)` should
+      return a valid output.
+    input_kwargs: A `dict` containing the second part of `layer_instance`
+      input. Specifically, `layer_instance(*input_args, **input_kwargs)`
+      should return a valid output.
     tape: A `tf.GradientTape` instance that will be used to watch the output
       `base_vars`.
     num_microbatches: An optional numeric value or scalar `tf.Tensor` for
@@ -169,11 +177,14 @@ def dense_layer_computation(
   trainable variables in `layer_instance`. These squared norms should be a 1D
   `tf.Tensor` of length `batch_size`.
   """
-  if len(inputs) != 1:
+  if input_kwargs:
+    raise ValueError("Dense layer calls should not receive kwargs.")
+  del input_kwargs  # Unused in dense layer calls.
+  if len(input_args) != 1:
     raise ValueError("Only layer inputs of length 1 are permitted.")
   orig_activation = layer_instance.activation
   layer_instance.activation = None
-  base_vars = layer_instance(*inputs)
+  base_vars = layer_instance(*input_args)
   tape.watch(base_vars)
   layer_instance.activation = orig_activation
   outputs = orig_activation(base_vars) if orig_activation else base_vars
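The new guard turns stray keyword arguments into an immediate error instead of
silently dropping them. A quick check (the import path is an assumption):

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

    dense = tf.keras.layers.Dense(3)
    with tf.GradientTape() as tape:
      try:
        layer_registry.dense_layer_computation(
            dense, (tf.ones([2, 4]),), {"training": True}, tape
        )
      except ValueError as e:
        print(e)  # Dense layer calls should not receive kwargs.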
@@ -188,7 +199,7 @@ def dense_layer_computation(
     # Special handling for better efficiency
     return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))

-  inputs_gram = _compute_gramian(*inputs)
+  inputs_gram = _compute_gramian(*input_args)
   base_vars_grads_gram = _compute_gramian(base_vars_grads)
   if layer_instance.use_bias:
     # Adding a bias term is equivalent to a layer with no bias term and which
@@ -206,7 +217,8 @@ def dense_layer_computation(

 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
-    inputs: Tuple[InputTensor],
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> RegistryFunctionOutput:
@@ -220,33 +232,23 @@ def embedding_layer_computation(

   Args:
     layer_instance: A `tf.keras.layers.Embedding` instance.
-    inputs: A tuple containing a single `InputTensor` which can be passed into
-      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
-      output.
-    tape: A `tf.GradientTape` instance that will be used to watch the output
-      `base_vars`.
-    num_microbatches: An optional numeric value or scalar `tf.Tensor` for
-      indicating whether and how the losses are grouped into microbatches. If
-      not None, num_microbatches must divide the batch size.
+    input_args: See `dense_layer_computation()`.
+    input_kwargs: See `dense_layer_computation()`.
+    tape: See `dense_layer_computation()`.
+    num_microbatches: See `dense_layer_computation()`.

   Returns:
-    A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
-    intermediate Tensor used in the chain-rule / "fast" clipping trick,
-    `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
-    a function that takes one input, a `tf.Tensor` that represents the output
-    of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
-    `tf.GradientTape` instance that records the dense layer computation and
-    `summed_loss` is the sum of the per-example losses of the underlying model.
-    This function then returns the per-example squared L2 gradient norms of the
-    trainable variables in `layer_instance`. These squared norms should be a 1D
-    `tf.Tensor` of length `batch_size`.
+    See `dense_layer_computation()`.
   """
-  if len(inputs) != 1:
+  if input_kwargs:
+    raise ValueError("Embedding layer calls should not receive kwargs.")
+  del input_kwargs  # Unused in embedding layer calls.
+  if len(input_args) != 1:
     raise ValueError("Only layer inputs of length 1 are permitted.")
   if hasattr(layer_instance, "sparse"):  # for backwards compatibility
     if layer_instance.sparse:
       raise NotImplementedError("Sparse output tensors are not supported.")
-  if isinstance(inputs[0], tf.SparseTensor):
+  if isinstance(input_args[0], tf.SparseTensor):
     raise NotImplementedError("Sparse input tensors are not supported.")

   # Disable experimental features.
@@ -256,7 +258,7 @@ def embedding_layer_computation(
         "The experimental embedding feature"
         "'_use_one_hot_matmul' is not supported."
     )
-  input_ids = tf.cast(*inputs, tf.int32)
+  input_ids = tf.cast(*input_args, tf.int32)
   base_vars = layer_instance.trainable_variables[0]
   tape.watch(base_vars)
   outputs = tf.nn.embedding_lookup(base_vars, input_ids)
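Putting the embedding pieces together: `base_vars` is the embedding table
itself, and the registry output feeds the fast-clipping norm computation as
sketched below (the import path is an assumption, and the layer is built up
front so `trainable_variables` is populated):

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

    emb = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
    ids = tf.constant([[1, 2], [3, 4]])
    emb.build(ids.shape)  # ensure the embedding table exists

    with tf.GradientTape() as tape:
      base_vars, outputs, sqr_norm_fn = layer_registry.embedding_layer_computation(
          emb, (ids,), {}, tape
      )
      summed_loss = tf.reduce_sum(outputs)

    grads = tape.gradient(summed_loss, base_vars)
    per_example_sqr_norms = sqr_norm_fn(grads)  # 1D tensor of length batch_size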