Add a `kwargs` argument to the registry API, plus small changes to docstrings.

This is a forward-looking change needed to support more complicated layers, such as `tf.keras.layers.MultiHeadAttention`, whose `.call()` method can take `kwargs` and can generate arbitrary outputs.

PiperOrigin-RevId: 514775503
parent 21ee1a607a
commit 4e1fc252e4

4 changed files with 51 additions and 44 deletions
@@ -52,7 +52,7 @@ def get_registry_generator_fn(
         )
       registry_fn = layer_registry.lookup(layer_instance)
       (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
-          layer_instance, args, tape, num_microbatches
+          layer_instance, args, kwargs, tape, num_microbatches
       )
       return layer_outputs, (layer_vars, layer_sqr_norm_fn)
     else:
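For context, after this change every registry function receives the layer's positional and keyword inputs separately. Below is a minimal sketch of a conforming registry function; the layer and the body are illustrative, and only the five-argument shape comes from this commit:

import tensorflow as tf

def my_layer_computation(
    layer_instance,         # the Keras layer being processed
    input_args,             # tuple of positional inputs to the layer call
    input_kwargs,           # dict of keyword inputs to the layer call
    tape,                   # tf.GradientTape that watches `base_vars`
    num_microbatches=None,
):
  # Illustrative only: run the layer on both halves of its input.
  base_vars = layer_instance(*input_args, **input_kwargs)
  tape.watch(base_vars)

  def sqr_norm_fn(base_vars_grads):
    # Illustrative per-example squared L2 norms over non-batch axes.
    return tf.reduce_sum(
        tf.square(base_vars_grads),
        axis=tf.range(1, tf.rank(base_vars_grads)),
    )

  return base_vars, base_vars, sqr_norm_fn  # (base_vars, outputs, sqr_norm_fn)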
@@ -13,7 +13,7 @@
 # limitations under the License.

 import itertools
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union

 from absl.testing import parameterized
 import tensorflow as tf
@@ -50,16 +50,17 @@ class DoubleDense(tf.keras.layers.Layer):

 def double_dense_layer_computation(
     layer_instance: tf.keras.layers.Layer,
-    inputs: Any,
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[int],
-):
+) -> layer_registry.RegistryFunctionOutput:
   """Layer registry function for the custom `DoubleDense` layer class."""
   vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
-      layer_instance.dense1, inputs, tape, num_microbatches
+      layer_instance.dense1, input_args, input_kwargs, tape, num_microbatches
   )
   vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation(
-      layer_instance.dense2, (outputs,), tape, num_microbatches
+      layer_instance.dense2, (outputs,), {}, tape, num_microbatches
   )

   def sqr_norm_fn(base_vars):
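To exercise this custom computation in a test, it would be registered alongside the defaults. A sketch, assuming the `make_default_layer_registry()` and `LayerRegistry.insert()` helpers defined in `layer_registry.py`:

# Assumes make_default_layer_registry() and LayerRegistry.insert() as
# defined in layer_registry.py.
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)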
@@ -75,7 +76,7 @@ def compute_true_gradient_norms(
     x_batch: tf.Tensor,
     y_batch: tf.Tensor,
     num_microbatches: Optional[int],
-):
+) -> layer_registry.OutputTensor:
   """Computes the real gradient norms for an input `(model, x, y)`."""
   loss_config = input_model.loss.get_config()
   loss_config['reduction'] = tf.keras.losses.Reduction.NONE
@@ -113,7 +114,7 @@ def get_computed_and_true_norms(
     x_input: tf.Tensor,
     rng_seed: int = 777,
     registry: layer_registry.LayerRegistry = None,
-):
+) -> Tuple[tf.Tensor, tf.Tensor]:
   """Obtains the true and computed gradient norms for a model and batch input.

   Helpful testing wrapper function used to avoid code duplication.
@@ -120,7 +120,11 @@ def model_forward_pass(
         layer, args, kwargs
     )
     generator_outputs_list.append(layer_generator_outputs)
-    args = (node_layer_outputs,)
+    args = (
+        node_layer_outputs
+        if isinstance(node_layer_outputs, tuple)
+        else (node_layer_outputs,)
+    )
     kwargs = {}

     # Update the current dictionary of inputs for the next node.
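This replaces the unconditional single-element wrap: multi-output layers (the motivating case for this change) now pass their output tuples straight through as the next node's `args`, while single outputs are still wrapped exactly once. A standalone illustration of the rule:

import tensorflow as tf

def as_args(node_layer_outputs):
  # Mirrors the new wrapping logic in model_forward_pass.
  return (
      node_layer_outputs
      if isinstance(node_layer_outputs, tuple)
      else (node_layer_outputs,)
  )

single = tf.ones((2, 3))                    # typical single-output layer
multi = (tf.ones((2, 3)), tf.ones((2, 2)))  # e.g. output + attention scores

print(len(as_args(single)))  # 1 -- wrapped exactly once
print(len(as_args(multi)))   # 2 -- passed through unchanged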
@@ -56,17 +56,21 @@ import tensorflow as tf
 # ==============================================================================
 # Type aliases
 # ==============================================================================
-SquareNormFunction = Callable[[Any], tf.Tensor]
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]

-RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]
+OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]]
+
+BatchSize = Union[int, tf.Tensor]
+
+SquareNormFunction = Callable[[OutputTensor], tf.Tensor]
+
+RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction]

 RegistryFunction = Callable[
-    [Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput
+    [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
+    RegistryFunctionOutput,
 ]

-InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
-BatchSize = Union[int, tf.Tensor]
-

 # ==============================================================================
 # Main class
@@ -134,7 +138,8 @@ def add_microbatch_axis(
 # ==============================================================================
 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
-    inputs: Tuple[InputTensor],
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> RegistryFunctionOutput:
@@ -148,9 +153,12 @@ def dense_layer_computation(

   Args:
     layer_instance: A `tf.keras.layers.Dense` instance.
-    inputs: A tuple containing a single `InputTensor` which can be passed into
-      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
-      output.
+    input_args: A `tuple` containing the first part of `layer_instance` input.
+      Specifically, `layer_instance(*input_args, **input_kwargs)` should return
+      a valid output.
+    input_kwargs: A `dict` containing the second part of `layer_instance`
+      input. Specifically, `layer_instance(*input_args, **input_kwargs)` should
+      return a valid output.
     tape: A `tf.GradientTape` instance that will be used to watch the output
       `base_vars`.
     num_microbatches: An optional numeric value or scalar `tf.Tensor` for
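Concretely, the only contract in the two new Args is that they reassemble the layer call, i.e. `layer_instance(*input_args, **input_kwargs)` is valid. An illustrative pair for the kind of layer motivating this change:

# Illustrative (input_args, input_kwargs) pair for a MultiHeadAttention call:
import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = tf.random.normal((8, 5, 4))
input_args = (x, x)                               # query, value
input_kwargs = {"return_attention_scores": True}  # keyword part of the call
outputs, scores = mha(*input_args, **input_kwargs)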
@@ -169,11 +177,14 @@ def dense_layer_computation(
     trainable variables in `layer_instance`. These squared norms should be a 1D
     `tf.Tensor` of length `batch_size`.
   """
-  if len(inputs) != 1:
+  if input_kwargs:
+    raise ValueError("Dense layer calls should not receive kwargs.")
+  del input_kwargs  # Unused in dense layer calls.
+  if len(input_args) != 1:
     raise ValueError("Only layer inputs of length 1 are permitted.")
   orig_activation = layer_instance.activation
   layer_instance.activation = None
-  base_vars = layer_instance(*inputs)
+  base_vars = layer_instance(*input_args)
   tape.watch(base_vars)
   layer_instance.activation = orig_activation
   outputs = orig_activation(base_vars) if orig_activation else base_vars
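A hedged end-to-end sketch of the updated `dense_layer_computation()` call path; everything around the call itself is illustrative:

import tensorflow as tf

layer = tf.keras.layers.Dense(4, activation="relu")
x = tf.random.normal((8, 3))  # batch of 8 examples

with tf.GradientTape() as tape:
  # New signature: positional inputs as a tuple, keyword inputs as a dict.
  base_vars, outputs, sqr_norm_fn = dense_layer_computation(
      layer, (x,), {}, tape, num_microbatches=None
  )
  summed_loss = tf.reduce_sum(outputs)  # stand-in for a real summed loss

base_vars_grads = tape.gradient(summed_loss, base_vars)
per_example_sq_norms = sqr_norm_fn(base_vars_grads)  # 1D tensor of length 8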
@@ -188,7 +199,7 @@ def dense_layer_computation(
       # Special handling for better efficiency
       return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))

-    inputs_gram = _compute_gramian(*inputs)
+    inputs_gram = _compute_gramian(*input_args)
     base_vars_grads_gram = _compute_gramian(base_vars_grads)
     if layer_instance.use_bias:
       # Adding a bias term is equivalent to a layer with no bias term and which
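The fast path above sums squares over every non-batch axis to get per-example squared norms; a quick numeric check of that reduction:

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # two examples
sq_norms = tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
print(sq_norms.numpy())  # [ 5. 25.] -> 1+4 and 9+16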
@@ -206,7 +217,8 @@ def dense_layer_computation(

 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
-    inputs: Tuple[InputTensor],
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[Text, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> RegistryFunctionOutput:
@@ -220,33 +232,23 @@ def embedding_layer_computation(

   Args:
     layer_instance: A `tf.keras.layers.Embedding` instance.
-    inputs: A tuple containing a single `InputTensor` which can be passed into
-      the layer instance, i.e., `layer_instance(*inputs)` returns a valid
-      output.
-    tape: A `tf.GradientTape` instance that will be used to watch the output
-      `base_vars`.
-    num_microbatches: An optional numeric value or scalar `tf.Tensor` for
-      indicating whether and how the losses are grouped into microbatches. If
-      not None, num_microbatches must divide the batch size.
+    input_args: See `dense_layer_computation()`.
+    input_kwargs: See `dense_layer_computation()`.
+    tape: See `dense_layer_computation()`.
+    num_microbatches: See `dense_layer_computation()`.

   Returns:
-    A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
-    intermediate Tensor used in the chain-rule / "fast" clipping trick,
-    `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
-    a function that takes one input, a `tf.Tensor` that represents the output
-    of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
-    `tf.GradientTape` instance that records the dense layer computation and
-    `summed_loss` is the sum of the per-example losses of the underlying model.
-    This function then returns the per-example squared L2 gradient norms of the
-    trainable variables in `layer_instance`. These squared norms should be a 1D
-    `tf.Tensor` of length `batch_size`.
+    See `dense_layer_computation()`.
   """
-  if len(inputs) != 1:
+  if input_kwargs:
+    raise ValueError("Embedding layer calls should not receive kwargs.")
+  del input_kwargs  # Unused in embedding layer calls.
+  if len(input_args) != 1:
     raise ValueError("Only layer inputs of length 1 are permitted.")
   if hasattr(layer_instance, "sparse"):  # for backwards compatibility
     if layer_instance.sparse:
       raise NotImplementedError("Sparse output tensors are not supported.")
-  if isinstance(inputs[0], tf.SparseTensor):
+  if isinstance(input_args[0], tf.SparseTensor):
     raise NotImplementedError("Sparse input tensors are not supported.")

   # Disable experimental features.
@@ -256,7 +258,7 @@ def embedding_layer_computation(
         "The experimental embedding feature"
         "'_use_one_hot_matmul' is not supported."
     )
-  input_ids = tf.cast(*inputs, tf.int32)
+  input_ids = tf.cast(*input_args, tf.int32)
   base_vars = layer_instance.trainable_variables[0]
   tape.watch(base_vars)
   outputs = tf.nn.embedding_lookup(base_vars, input_ids)
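Finally, a matching usage sketch for `embedding_layer_computation()` under the new signature; the layer must be built first so `trainable_variables[0]` exists, and the surrounding plumbing is illustrative:

import tensorflow as tf

emb = tf.keras.layers.Embedding(input_dim=100, output_dim=8)
ids = tf.constant([[1, 2, 3], [4, 5, 6]])  # batch of 2 sequences
emb.build(ids.shape)  # create the embedding table

with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = embedding_layer_computation(
      emb, (ids,), {}, tape
  )
  summed_loss = tf.reduce_sum(outputs)

grads = tape.gradient(summed_loss, base_vars)
per_example_sq_norms = sqr_norm_fn(grads)  # 1D tensor of length 2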