Generalize the registry function for the embedding layer for other models.

PiperOrigin-RevId: 509528743
commit 430f103354
parent 410814ec39
Author: A. Unique TensorFlower
Date:   2023-02-14 07:58:37 -08:00


@@ -136,8 +136,8 @@ def embedding_layer_computation(layer_instance, inputs):
       `layer_instance(inputs)` returns a valid output.
 
   Returns:
-    A `tuple` `(base_vars, None, sqr_norm_fn)`, `base_vars` is the
-    intermediate Tensor used in the chain-rule / "fast" clipping trick, and
+    A `tuple` `(base_vars, transform, sqr_norm_fn)`, where `base_vars` is the
+    intermediate Tensor used in the chain-rule / "fast" clipping trick,
     `sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
     represents the output of the call `tape.gradient(summed_loss, base_vars)`
     where `tape` is a `tf.GradientTape` instance that records the dense
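
Under the new contract, the second element of the triple is no longer a hard-coded `None` but a `transform` callable (below, an embedding lookup). A minimal sketch of how a caller might consume the triple, assuming the post-change `embedding_layer_computation` is in scope; `layer` and `ids` are illustrative stand-ins, not code from this commit:

import tensorflow as tf

# Illustrative stand-ins: a built embedding layer and a toy batch of ids.
layer = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
ids = tf.constant([[0, 0, 2], [1, 3, 3]])
layer(ids)  # forces the layer to create its trainable variables

base_vars, transform, sqr_norm_fn = embedding_layer_computation(layer, (ids,))

with tf.GradientTape() as tape:
  # `base_vars` is the embedding variable; recompute the layer output from it.
  outputs = transform(base_vars) if transform is not None else base_vars
  summed_loss = tf.reduce_sum(outputs)

grads = tape.gradient(summed_loss, base_vars)  # a `tf.IndexedSlices`
per_example_sqr_norms = sqr_norm_fn(grads)  # shape [batch_size]

Since `base_vars` is now the embedding variable itself, `tape.gradient(summed_loss, base_vars)` yields a `tf.IndexedSlices` rather than a dense `Tensor`, which is exactly what the new `sqr_norm_fn` in the next hunk expects.
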
@@ -148,41 +148,72 @@ def embedding_layer_computation(layer_instance, inputs):
   """
   if hasattr(layer_instance, "sparse"):  # for backwards compatibility
     if layer_instance.sparse:
-      raise NotImplementedError("Sparse output vectors are not supported.")
-  if len(inputs[0].shape) != 2:
-    raise NotImplementedError("Only 2D embedding inputs are supported.")
-  # The logic below is applied to properly handle repeated embedding indices.
-  # Specifically, sqr_grad_norms will contain the total counts of each embedding
-  # index (see how it is processed in the combine_pre_and_post_sqr_norms()
-  # function in clip_grads.py). An example is as follows:
-  #
-  #   inputs =
-  #     [[0 0 0 1 2 2],
-  #      [0 2 2 2 1 1]]
-  #
-  #   counts =
-  #     [[3 1 2],
-  #      [1 2 3]]
-  #
-  #   input_counts =
-  #     [[3 3 3 1 2 2],
-  #      [1 3 3 3 2 2]]
-  #
-  base_vars = layer_instance(*inputs)
-
-  def sqr_norm_fn(base_vars_grads):
-    indices = tf.cast(*inputs, tf.int32)
-    if isinstance(indices, tf.SparseTensor):
-      indices = tf.sparse.to_dense(indices)
-    counts = tf.math.bincount(indices, axis=-1)
-    input_counts = tf.expand_dims(
-        tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
-        axis=-1,
-    )
-    scaled_grads = input_counts * tf.square(base_vars_grads)
-    reduction_axes = tf.range(1, tf.rank(scaled_grads))
-    return tf.reduce_sum(scaled_grads, axis=reduction_axes)
-
-  return base_vars, None, sqr_norm_fn
+      raise NotImplementedError("Sparse output tensors are not supported.")
+  if isinstance(inputs, tf.SparseTensor):
+    raise NotImplementedError("Sparse input tensors are not supported.")
+  # Disable experimental features.
+  if hasattr(layer_instance, "_use_one_hot_matmul"):
+    if layer_instance._use_one_hot_matmul:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "The experimental embedding feature "
+          "'_use_one_hot_matmul' is not supported."
+      )
+  input_ids = tf.cast(*inputs, tf.int32)
+  base_vars = layer_instance.trainable_variables[0]
+
+  def lookup_inputs(embeddings):
+    return tf.nn.embedding_lookup(embeddings, input_ids)
+
+  def sqr_norm_fn(base_vars_grads):
+    # Get a 1D tensor of the row indices.
+    nrows = tf.shape(input_ids)[0]
+    if isinstance(input_ids, tf.RaggedTensor):
+      row_indices = tf.expand_dims(
+          input_ids.merge_dims(1, -1).value_rowids(), axis=-1
+      )
+    elif isinstance(input_ids, tf.Tensor):
+      ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
+      repeats = tf.repeat(ncols, nrows)
+      row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
+    else:
+      raise NotImplementedError(
+          "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
+      )
+    # Sum-reduce the `IndexedSlices` that is the result of a `tape.gradient()`
+    # call. The sum is reduced by the repeated embedding indices and batch
+    # index. It is adapted from the logic in:
+    #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
+    if not isinstance(base_vars_grads, tf.IndexedSlices):
+      raise NotImplementedError(
+          "Cannot parse embedding gradients of type: %s"
+          % base_vars_grads.__class__.__name__
+      )
+    slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
+    paired_indices = tf.concat(
+        [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
+        axis=1,
+    )
+    (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
+        x=paired_indices, axis=[0]
+    )
+    unique_batch_ids = unique_paired_indices[:, 0]
+    summed_gradients = tf.math.unsorted_segment_sum(
+        base_vars_grads.values,
+        new_index_positions,
+        tf.shape(unique_paired_indices)[0],
+    )
+    # Compute the squared gradient norms at the per-example level.
+    sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
+    summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
+    return tf.sparse.segment_sum(
+        sqr_gradient_sum,
+        summed_data_range,
+        tf.sort(unique_batch_ids),
+        num_segments=nrows,
+    )  # fill in empty inputs
+
+  return base_vars, lookup_inputs, sqr_norm_fn
 
 
 # ==============================================================================
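
The core of the new `sqr_norm_fn` is the pair-then-deduplicate step: each gradient slice is keyed by a (batch row, embedding id) pair, duplicate pairs are summed, and the squared slice norms are then re-aggregated per example. A self-contained sketch of that technique on toy numbers; all values below are illustrative assumptions, not data from the commit:

import tensorflow as tf

# Toy recreation of the dedup step: two examples, each looking up three ids.
# Example 0 uses ids [0, 0, 2]; example 1 uses ids [1, 3, 3].
row_indices = tf.constant([[0], [0], [0], [1], [1], [1]], dtype=tf.int64)
slice_indices = tf.constant([[0], [0], [2], [1], [3], [3]], dtype=tf.int64)
values = tf.ones([6, 4])  # stand-in for `base_vars_grads.values`

paired_indices = tf.concat([row_indices, slice_indices], axis=1)
unique_pairs, positions = tf.raw_ops.UniqueV2(x=paired_indices, axis=[0])
# unique_pairs == [[0, 0], [0, 2], [1, 1], [1, 3]]: the repeated id 0 in
# example 0 and the repeated id 3 in example 1 each collapse to one row.
summed = tf.math.unsorted_segment_sum(
    values, positions, tf.shape(unique_pairs)[0]
)
sqr_sum = tf.reduce_sum(tf.square(summed), axis=1)  # [16., 4., 4., 16.]
per_example = tf.sparse.segment_sum(
    sqr_sum,
    tf.range(tf.shape(sqr_sum)[0]),
    tf.sort(unique_pairs[:, 0]),
    num_segments=2,
)
print(per_example.numpy())  # [20. 20.]

Merging duplicates before squaring matters: for a repeated id, the true per-example contribution is the squared norm of the summed slices ((1 + 1)^2 * 4 = 16 here), not the sum of individually squared slices (1^2 * 4 + 1^2 * 4 = 8).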