forked from 626_privacy/tensorflow_privacy
Improve documentation and logging of fast gradient clipping modules and callers.
PiperOrigin-RevId: 513283486
This commit is contained in:
parent
d7cd3f8af1
commit
7436930c64
5 changed files with 129 additions and 42 deletions
|
@ -6,6 +6,7 @@ py_library(
|
||||||
name = "gradient_clipping_utils",
|
name = "gradient_clipping_utils",
|
||||||
srcs = ["gradient_clipping_utils.py"],
|
srcs = ["gradient_clipping_utils.py"],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
|
deps = [":layer_registry"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
|
@ -21,11 +21,18 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
|
||||||
`compute_gradient_norms()` function).
|
`compute_gradient_norms()` function).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Union, Iterable, Text
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||||
|
|
||||||
|
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
def get_registry_generator_fn(tape, layer_registry):
|
def get_registry_generator_fn(
|
||||||
|
tape: tf.GradientTape, layer_registry: lr.LayerRegistry
|
||||||
|
):
|
||||||
"""Creates the generator function for `compute_gradient_norms()`."""
|
"""Creates the generator function for `compute_gradient_norms()`."""
|
||||||
if layer_registry is None:
|
if layer_registry is None:
|
||||||
# Needed for backwards compatibility.
|
# Needed for backwards compatibility.
|
||||||
|
@ -53,7 +60,12 @@ def get_registry_generator_fn(tape, layer_registry):
|
||||||
return registry_generator_fn
|
return registry_generator_fn
|
||||||
|
|
||||||
|
|
||||||
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
def compute_gradient_norms(
|
||||||
|
input_model: tf.keras.Model,
|
||||||
|
x_batch: InputTensor,
|
||||||
|
y_batch: tf.Tensor,
|
||||||
|
layer_registry: lr.LayerRegistry,
|
||||||
|
):
|
||||||
"""Computes the per-example loss gradient norms for given data.
|
"""Computes the per-example loss gradient norms for given data.
|
||||||
|
|
||||||
Applies a variant of the approach given in
|
Applies a variant of the approach given in
|
||||||
|
@ -62,7 +74,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||||
Args:
|
Args:
|
||||||
input_model: The `tf.keras.Model` from which to obtain the layers. The
|
input_model: The `tf.keras.Model` from which to obtain the layers. The
|
||||||
loss of the model *must* be a scalar loss.
|
loss of the model *must* be a scalar loss.
|
||||||
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
|
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||||
first axis must be the batch dimension.
|
first axis must be the batch dimension.
|
||||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||||
must be the batch dimension. The number of examples should match the
|
must be the batch dimension. The number of examples should match the
|
||||||
|
@ -106,7 +118,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||||
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
|
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
|
||||||
|
|
||||||
|
|
||||||
def compute_clip_weights(l2_norm_clip, gradient_norms):
|
def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
|
||||||
"""Computes the per-example loss/clip weights for clipping.
|
"""Computes the per-example loss/clip weights for clipping.
|
||||||
|
|
||||||
When the sum of the per-example losses is replaced by a weighted sum, where
|
When the sum of the per-example losses is replaced by a weighted sum, where
|
||||||
|
@ -132,7 +144,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
|
||||||
|
|
||||||
|
|
||||||
def compute_pred_and_clipped_gradients(
|
def compute_pred_and_clipped_gradients(
|
||||||
input_model, x_batch, y_batch, l2_norm_clip, layer_registry
|
input_model: tf.keras.Model,
|
||||||
|
x_batch: InputTensor,
|
||||||
|
y_batch: tf.Tensor,
|
||||||
|
l2_norm_clip: float,
|
||||||
|
layer_registry: lr.LayerRegistry,
|
||||||
):
|
):
|
||||||
"""Computes the per-example predictions and per-example clipped loss gradient.
|
"""Computes the per-example predictions and per-example clipped loss gradient.
|
||||||
|
|
||||||
|
@ -147,7 +163,7 @@ def compute_pred_and_clipped_gradients(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_model: The `tf.keras.Model` from which to obtain the layers.
|
input_model: The `tf.keras.Model` from which to obtain the layers.
|
||||||
x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
|
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||||
first axis must be the batch dimension.
|
first axis must be the batch dimension.
|
||||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||||
must be the batch dimension. The number of examples should match the
|
must be the batch dimension. The number of examples should match the
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
from typing import Callable, Any, List, Union
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -20,23 +22,35 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Type aliases
|
||||||
|
# ==============================================================================
|
||||||
|
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
|
||||||
|
|
||||||
|
ModelGenerator = Callable[
|
||||||
|
[LayerGenerator, Union[int, List[int]], int], tf.keras.Model
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Helper functions and classes.
|
# Helper functions and classes.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class DoubleDense(tf.keras.layers.Layer):
|
class DoubleDense(tf.keras.layers.Layer):
|
||||||
"""Generates two dense layers nested together."""
|
"""Generates two dense layers nested together."""
|
||||||
|
|
||||||
def __init__(self, units):
|
def __init__(self, units: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense1 = tf.keras.layers.Dense(units)
|
self.dense1 = tf.keras.layers.Dense(units)
|
||||||
self.dense2 = tf.keras.layers.Dense(1)
|
self.dense2 = tf.keras.layers.Dense(1)
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs: Any):
|
||||||
x = self.dense1(inputs)
|
x = self.dense1(inputs)
|
||||||
return self.dense2(x)
|
return self.dense2(x)
|
||||||
|
|
||||||
|
|
||||||
def double_dense_layer_computation(layer_instance, inputs, tape):
|
def double_dense_layer_computation(
|
||||||
|
layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape
|
||||||
|
):
|
||||||
"""Layer registry function for the custom `DoubleDense` layer class."""
|
"""Layer registry function for the custom `DoubleDense` layer class."""
|
||||||
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
|
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
|
||||||
layer_instance.dense1, inputs, tape
|
layer_instance.dense1, inputs, tape
|
||||||
|
@ -53,7 +67,9 @@ def double_dense_layer_computation(layer_instance, inputs, tape):
|
||||||
return [vars1, vars2], outputs, sqr_norm_fn
|
return [vars1, vars2], outputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
def compute_true_gradient_norms(input_model, x_batch, y_batch):
|
def compute_true_gradient_norms(
|
||||||
|
input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor
|
||||||
|
):
|
||||||
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
"""Computes the real gradient norms for an input `(model, x, y)`."""
|
||||||
loss_config = input_model.loss.get_config()
|
loss_config = input_model.loss.get_config()
|
||||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||||
|
@ -73,14 +89,14 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch):
|
||||||
|
|
||||||
|
|
||||||
def get_computed_and_true_norms(
|
def get_computed_and_true_norms(
|
||||||
model_generator,
|
model_generator: ModelGenerator,
|
||||||
layer_generator,
|
layer_generator: LayerGenerator,
|
||||||
input_dims,
|
input_dims: Union[int, List[int]],
|
||||||
output_dim,
|
output_dim: int,
|
||||||
is_eager,
|
is_eager: bool,
|
||||||
x_input,
|
x_input: tf.Tensor,
|
||||||
rng_seed=777,
|
rng_seed: int = 777,
|
||||||
registry=None,
|
registry: layer_registry.LayerRegistry = None,
|
||||||
):
|
):
|
||||||
"""Obtains the true and computed gradient norms for a model and batch input.
|
"""Obtains the true and computed gradient norms for a model and batch input.
|
||||||
|
|
||||||
|
@ -238,7 +254,7 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Factory functions.
|
# Factory functions.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
def get_nd_test_tensors(n):
|
def get_nd_test_tensors(n: int):
|
||||||
"""Returns a list of candidate tests for a given dimension n."""
|
"""Returns a list of candidate tests for a given dimension n."""
|
||||||
return [
|
return [
|
||||||
tf.zeros((n,), dtype=tf.float64),
|
tf.zeros((n,), dtype=tf.float64),
|
||||||
|
@ -246,7 +262,7 @@ def get_nd_test_tensors(n):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_nd_test_batches(n):
|
def get_nd_test_batches(n: int):
|
||||||
"""Returns a list of candidate input batches of dimension n."""
|
"""Returns a list of candidate input batches of dimension n."""
|
||||||
result = []
|
result = []
|
||||||
tensors = get_nd_test_tensors(n)
|
tensors = get_nd_test_tensors(n)
|
||||||
|
|
|
@ -13,11 +13,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utility functions that help in the computation of per-example gradient norms."""
|
"""Utility functions that help in the computation of per-example gradient norms."""
|
||||||
|
|
||||||
|
from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||||
|
|
||||||
def has_internal_compute_graph(input_object):
|
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
||||||
|
|
||||||
|
GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]]
|
||||||
|
|
||||||
|
|
||||||
|
def has_internal_compute_graph(input_object: Any):
|
||||||
"""Checks if input is a TF model and has a TF internal compute graph."""
|
"""Checks if input is a TF model and has a TF internal compute graph."""
|
||||||
return (
|
return (
|
||||||
isinstance(input_object, tf.keras.Model)
|
isinstance(input_object, tf.keras.Model)
|
||||||
|
@ -28,7 +36,9 @@ def has_internal_compute_graph(input_object):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_internal_layers(input_layer):
|
def _get_internal_layers(
|
||||||
|
input_layer: tf.keras.layers.Layer,
|
||||||
|
) -> list[tf.keras.layers.Layer]:
|
||||||
"""Returns a list of layers that are nested within a given layer."""
|
"""Returns a list of layers that are nested within a given layer."""
|
||||||
internal_layers = []
|
internal_layers = []
|
||||||
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
||||||
|
@ -39,7 +49,11 @@ def _get_internal_layers(input_layer):
|
||||||
return internal_layers
|
return internal_layers
|
||||||
|
|
||||||
|
|
||||||
def model_forward_pass(input_model, inputs, generator_fn=None):
|
def model_forward_pass(
|
||||||
|
input_model: tf.keras.Model,
|
||||||
|
inputs: InputTensor,
|
||||||
|
generator_fn: GeneratorFunction = None,
|
||||||
|
) -> Tuple[tf.Tensor, list[Any]]:
|
||||||
"""Does a forward pass of a model and returns useful intermediates.
|
"""Does a forward pass of a model and returns useful intermediates.
|
||||||
|
|
||||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||||
|
@ -118,7 +132,9 @@ def model_forward_pass(input_model, inputs, generator_fn=None):
|
||||||
return node_layer_outputs, generator_outputs_list
|
return node_layer_outputs, generator_outputs_list
|
||||||
|
|
||||||
|
|
||||||
def all_trainable_layers_are_registered(input_model, layer_registry):
|
def all_trainable_layers_are_registered(
|
||||||
|
input_model: tf.keras.Model, layer_registry: lr.LayerRegistry
|
||||||
|
) -> bool:
|
||||||
"""Check if an input model's trainable layers are all registered.
|
"""Check if an input model's trainable layers are all registered.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -140,18 +156,21 @@ def all_trainable_layers_are_registered(input_model, layer_registry):
|
||||||
|
|
||||||
|
|
||||||
def add_aggregate_noise(
|
def add_aggregate_noise(
|
||||||
input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier
|
input_model: tf.keras.Model,
|
||||||
):
|
x_batch: InputTensor,
|
||||||
|
clipped_grads: list[tf.Tensor],
|
||||||
|
l2_norm_clip: float,
|
||||||
|
noise_multiplier: float,
|
||||||
|
) -> list[tf.Tensor]:
|
||||||
"""Adds noise to a collection of clipped gradients.
|
"""Adds noise to a collection of clipped gradients.
|
||||||
|
|
||||||
The magnitude of the noise depends on the aggregation strategy of the
|
The magnitude of the noise depends on the aggregation strategy of the
|
||||||
input model's loss function.
|
input model's loss function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_model: The Keras model to obtain the layers from.
|
input_model: The `tf.keras.Model` to obtain the layers from.
|
||||||
x_batch: A collection of Tensors to be fed into the input layer of the
|
x_batch: An `InputTensor` to be fed into the input layer of the model.
|
||||||
model.
|
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
|
||||||
clipped_grads: A list of tensors representing the clipped gradients.
|
|
||||||
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
|
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
|
||||||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||||
|
|
||||||
|
@ -187,7 +206,9 @@ def add_aggregate_noise(
|
||||||
return tf.nest.map_structure(add_noise, clipped_grads)
|
return tf.nest.map_structure(add_noise, clipped_grads)
|
||||||
|
|
||||||
|
|
||||||
def generate_model_outputs_using_core_keras_layers(input_model):
|
def generate_model_outputs_using_core_keras_layers(
|
||||||
|
input_model: tf.keras.Model,
|
||||||
|
) -> tf.Tensor:
|
||||||
"""Returns the model outputs generated by only core Keras layers."""
|
"""Returns the model outputs generated by only core Keras layers."""
|
||||||
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
||||||
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
||||||
|
|
|
@ -40,9 +40,24 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
|
||||||
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
|
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Callable, Type, Any, Union, Iterable, Text
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Type aliases
|
||||||
|
# ==============================================================================
|
||||||
|
SquareNormFunction = Callable[[Any], tf.Tensor]
|
||||||
|
|
||||||
|
RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction]
|
||||||
|
|
||||||
|
RegistryFunction = Callable[
|
||||||
|
[Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput
|
||||||
|
]
|
||||||
|
|
||||||
|
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Main class
|
# Main class
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
@ -54,15 +69,19 @@ class LayerRegistry:
|
||||||
self._layer_class_dict = {}
|
self._layer_class_dict = {}
|
||||||
self._registry = {}
|
self._registry = {}
|
||||||
|
|
||||||
def is_elem(self, layer_instance):
|
def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool:
|
||||||
"""Checks if a layer instance's class is in the registry."""
|
"""Checks if a layer instance's class is in the registry."""
|
||||||
return hash(layer_instance.__class__) in self._registry
|
return hash(layer_instance.__class__) in self._registry
|
||||||
|
|
||||||
def lookup(self, layer_instance):
|
def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction:
|
||||||
"""Returns the layer registry function for a given layer instance."""
|
"""Returns the layer registry function for a given layer instance."""
|
||||||
return self._registry[hash(layer_instance.__class__)]
|
return self._registry[hash(layer_instance.__class__)]
|
||||||
|
|
||||||
def insert(self, layer_class, layer_registry_function):
|
def insert(
|
||||||
|
self,
|
||||||
|
layer_class: Type[tf.keras.layers.Layer],
|
||||||
|
layer_registry_function: RegistryFunction,
|
||||||
|
):
|
||||||
"""Inserts a layer registry function into the internal dictionaries."""
|
"""Inserts a layer registry function into the internal dictionaries."""
|
||||||
layer_key = hash(layer_class)
|
layer_key = hash(layer_class)
|
||||||
self._layer_class_dict[layer_key] = layer_class
|
self._layer_class_dict[layer_key] = layer_class
|
||||||
|
@ -72,7 +91,11 @@ class LayerRegistry:
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Supported Keras layers
|
# Supported Keras layers
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
def dense_layer_computation(layer_instance, inputs, tape):
|
def dense_layer_computation(
|
||||||
|
layer_instance: tf.keras.layers.Dense,
|
||||||
|
inputs: tuple[InputTensor],
|
||||||
|
tape: tf.GradientTape,
|
||||||
|
) -> RegistryFunctionOutput:
|
||||||
"""Registry function for `tf.keras.layers.Dense`.
|
"""Registry function for `tf.keras.layers.Dense`.
|
||||||
|
|
||||||
The logic for this computation is based on the following paper:
|
The logic for this computation is based on the following paper:
|
||||||
|
@ -83,8 +106,9 @@ def dense_layer_computation(layer_instance, inputs, tape):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
layer_instance: A `tf.keras.layers.Dense` instance.
|
layer_instance: A `tf.keras.layers.Dense` instance.
|
||||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
inputs: A tuple containing a single `InputTensor` which can be passed into
|
||||||
`layer_instance(inputs)` returns a valid output.
|
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
|
||||||
|
output.
|
||||||
tape: A `tf.GradientTape` instance that will be used to watch the output
|
tape: A `tf.GradientTape` instance that will be used to watch the output
|
||||||
`base_vars`.
|
`base_vars`.
|
||||||
|
|
||||||
|
@ -100,6 +124,8 @@ def dense_layer_computation(layer_instance, inputs, tape):
|
||||||
trainable variables in `layer_instance`. These squared norms should be a 1D
|
trainable variables in `layer_instance`. These squared norms should be a 1D
|
||||||
`tf.Tensor` of length `batch_size`.
|
`tf.Tensor` of length `batch_size`.
|
||||||
"""
|
"""
|
||||||
|
if len(inputs) != 1:
|
||||||
|
raise ValueError("Only layer inputs of length 1 are permitted.")
|
||||||
orig_activation = layer_instance.activation
|
orig_activation = layer_instance.activation
|
||||||
layer_instance.activation = None
|
layer_instance.activation = None
|
||||||
base_vars = layer_instance(*inputs)
|
base_vars = layer_instance(*inputs)
|
||||||
|
@ -125,7 +151,11 @@ def dense_layer_computation(layer_instance, inputs, tape):
|
||||||
return base_vars, outputs, sqr_norm_fn
|
return base_vars, outputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
def embedding_layer_computation(layer_instance, inputs, tape):
|
def embedding_layer_computation(
|
||||||
|
layer_instance: tf.keras.layers.Embedding,
|
||||||
|
inputs: tuple[InputTensor],
|
||||||
|
tape: tf.GradientTape,
|
||||||
|
) -> RegistryFunctionOutput:
|
||||||
"""Registry function for `tf.keras.layers.Embedding`.
|
"""Registry function for `tf.keras.layers.Embedding`.
|
||||||
|
|
||||||
The logic of this computation is based on the `tf.keras.layers.Dense`
|
The logic of this computation is based on the `tf.keras.layers.Dense`
|
||||||
|
@ -136,8 +166,9 @@ def embedding_layer_computation(layer_instance, inputs, tape):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
layer_instance: A `tf.keras.layers.Embedding` instance.
|
layer_instance: A `tf.keras.layers.Embedding` instance.
|
||||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
inputs: A tuple containing a single `InputTensor` which can be passed into
|
||||||
`layer_instance(inputs)` returns a valid output.
|
the layer instance, i.e., `layer_instance(*inputs)` returns a valid
|
||||||
|
output.
|
||||||
tape: A `tf.GradientTape` instance that will be used to watch the output
|
tape: A `tf.GradientTape` instance that will be used to watch the output
|
||||||
`base_vars`.
|
`base_vars`.
|
||||||
|
|
||||||
|
@ -153,10 +184,12 @@ def embedding_layer_computation(layer_instance, inputs, tape):
|
||||||
trainable variables in `layer_instance`. These squared norms should be a 1D
|
trainable variables in `layer_instance`. These squared norms should be a 1D
|
||||||
`tf.Tensor` of length `batch_size`.
|
`tf.Tensor` of length `batch_size`.
|
||||||
"""
|
"""
|
||||||
|
if len(inputs) != 1:
|
||||||
|
raise ValueError("Only layer inputs of length 1 are permitted.")
|
||||||
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
||||||
if layer_instance.sparse:
|
if layer_instance.sparse:
|
||||||
raise NotImplementedError("Sparse output tensors are not supported.")
|
raise NotImplementedError("Sparse output tensors are not supported.")
|
||||||
if isinstance(inputs, tf.SparseTensor):
|
if isinstance(inputs[0], tf.SparseTensor):
|
||||||
raise NotImplementedError("Sparse input tensors are not supported.")
|
raise NotImplementedError("Sparse input tensors are not supported.")
|
||||||
|
|
||||||
# Disable experimental features.
|
# Disable experimental features.
|
||||||
|
@ -225,7 +258,7 @@ def embedding_layer_computation(layer_instance, inputs, tape):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Main factory methods
|
# Main factory methods
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
def make_default_layer_registry():
|
def make_default_layer_registry() -> LayerRegistry:
|
||||||
registry = LayerRegistry()
|
registry = LayerRegistry()
|
||||||
registry.insert(tf.keras.layers.Dense, dense_layer_computation)
|
registry.insert(tf.keras.layers.Dense, dense_layer_computation)
|
||||||
registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
|
registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
|
||||||
|
|
Loading…
Reference in a new issue