Generalize the registry function for the embedding layer to support other models.
PiperOrigin-RevId: 509528743
parent 410814ec39
commit 430f103354
1 changed file with 67 additions and 36 deletions
@@ -136,8 +136,8 @@ def embedding_layer_computation(layer_instance, inputs):
       `layer_instance(inputs)` returns a valid output.
 
   Returns:
-    A `tuple` `(base_vars, None, sqr_norm_fn)`, `base_vars` is the
-    intermediate Tensor used in the chain-rule / "fast" clipping trick, and
+    A `tuple` `(base_vars, transform, sqr_norm_fn)`, `base_vars` is the
+    intermediate Tensor used in the chain-rule / "fast" clipping trick,
     `sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
     represents the output of the call `tape.gradient(summed_loss, base_vars)`
     where `tape` is a `tf.GradientTape` instance that records the dense
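
The hunk above captures the API change: the middle element of the returned `tuple` was previously hard-coded to `None` and is now a `transform` that maps `base_vars` to the layer outputs. A minimal sketch of how a caller might consume the triple, assuming `embedding_layer_computation` is in scope; the direct call and the surrounding names are illustrative, not the library's documented usage:

import tensorflow as tf

layer = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
x = tf.constant([[1, 2, 2], [0, 3, 3]])  # a batch of token ids
layer.build(x.shape)

base_vars, transform, sqr_norm_fn = embedding_layer_computation(layer, (x,))

with tf.GradientTape() as tape:
  # After this commit, the forward pass goes through `transform` rather than
  # through a direct call of the layer on its inputs.
  outputs = transform(base_vars) if transform is not None else base_vars
  summed_loss = tf.reduce_sum(outputs)

grads = tape.gradient(summed_loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(grads)  # 1D tensor of length batch_size
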
@@ -148,41 +148,72 @@ def embedding_layer_computation(layer_instance, inputs):
   """
   if hasattr(layer_instance, "sparse"):  # for backwards compatibility
     if layer_instance.sparse:
-      raise NotImplementedError("Sparse output vectors are not supported.")
-  if len(inputs[0].shape) != 2:
-    raise NotImplementedError("Only 2D embedding inputs are supported.")
-  # The logic below is applied to properly handle repeated embedding indices.
-  # Specifically, sqr_grad_norms will contain the total counts of each embedding
-  # index (see how it is processed in the combine_pre_and_post_sqr_norms()
-  # function in clip_grads.py). An example is as follows:
-  #
-  #   inputs =
-  #     [[0 0 0 1 2 2],
-  #      [0 2 2 2 1 1]]
-  #
-  #   counts =
-  #     [[3 1 2]
-  #      [1 2 3]]
-  #
-  #   input_counts =
-  #     [[3 3 3 1 2 2],
-  #      [1 3 3 3 2 2]]
-  #
-  base_vars = layer_instance(*inputs)
-  def sqr_norm_fn(base_vars_grads):
-    indices = tf.cast(*inputs, tf.int32)
-    if isinstance(indices, tf.SparseTensor):
-      indices = tf.sparse.to_dense(indices)
-    counts = tf.math.bincount(indices, axis=-1)
-    input_counts = tf.expand_dims(
-        tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
-        axis=-1,
-    )
-    scaled_grads = input_counts * tf.square(base_vars_grads)
-    reduction_axes = tf.range(1, tf.rank(scaled_grads))
-    return tf.reduce_sum(scaled_grads, axis=reduction_axes)
+      raise NotImplementedError("Sparse output tensors are not supported.")
+  if isinstance(inputs, tf.SparseTensor):
+    raise NotImplementedError("Sparse input tensors are not supported.")
 
-  return base_vars, None, sqr_norm_fn
+  # Disable experimental features.
+  if hasattr(layer_instance, "_use_one_hot_matmul"):
+    if layer_instance._use_one_hot_matmul:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "The experimental embedding feature"
+          "'_use_one_hot_matmul' is not supported."
+      )
+  input_ids = tf.cast(*inputs, tf.int32)
+  base_vars = layer_instance.trainable_variables[0]
+
+  def lookup_inputs(embeddings):
+    return tf.nn.embedding_lookup(embeddings, input_ids)
+
+  def sqr_norm_fn(base_vars_grads):
+    # Get a 1D tensor of the row indices.
+    nrows = tf.shape(input_ids)[0]
+    if isinstance(input_ids, tf.RaggedTensor):
+      row_indices = tf.expand_dims(
+          input_ids.merge_dims(1, -1).value_rowids(), axis=-1
+      )
+    elif isinstance(input_ids, tf.Tensor):
+      ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
+      repeats = tf.repeat(ncols, nrows)
+      row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
+    else:
+      raise NotImplementedError(
+          "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
+      )
+    # Sum-reduce the `IndexedSlices` that is the result of a `tape.gradient()`
+    # call. The sum is reduced by the repeated embedding indices and batch
+    # index. It is adapted from the logic in:
+    #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
+    if not isinstance(base_vars_grads, tf.IndexedSlices):
+      raise NotImplementedError(
+          "Cannot parse embedding gradients of type: %s"
+          % base_vars_grads.__class__.__name__
+      )
+    slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
+    paired_indices = tf.concat(
+        [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
+        axis=1,
+    )
+    (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
+        x=paired_indices, axis=[0]
+    )
+    unique_batch_ids = unique_paired_indices[:, 0]
+    summed_gradients = tf.math.unsorted_segment_sum(
+        base_vars_grads.values,
+        new_index_positions,
+        tf.shape(unique_paired_indices)[0],
+    )
+    # Compute the squared gradient norms at the per-example level.
+    sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
+    summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
+    return tf.sparse.segment_sum(
+        sqr_gradient_sum,
+        summed_data_range,
+        tf.sort(unique_batch_ids),
+        num_segments=nrows,
+    )  # fill in empty inputs
+
+  return base_vars, lookup_inputs, sqr_norm_fn
 
 
 # ==============================================================================
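
For reference, the removed implementation never touched the gradient's `IndexedSlices`; it scaled the squared dense gradients by occurrence counts. The worked example from its deleted comment block can be reproduced directly; a standalone sketch, not part of the commit:

import tensorflow as tf

inputs = tf.constant([[0, 0, 0, 1, 2, 2],
                      [0, 2, 2, 2, 1, 1]])
counts = tf.math.bincount(inputs, axis=-1)
# counts == [[3 1 2],
#            [1 2 3]]
input_counts = tf.gather(counts, inputs, batch_dims=1)
# input_counts == [[3 3 3 1 2 2],
#                  [1 3 3 3 2 2]]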
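
The new `sqr_norm_fn` rests on the claim that summing the `IndexedSlices` values over repeated `(batch index, embedding index)` pairs before squaring yields the per-example squared gradient norms. A sanity-check sketch against brute-force per-example gradients, again assuming `embedding_layer_computation` is importable from the registry module:

import tensorflow as tf

layer = tf.keras.layers.Embedding(input_dim=5, output_dim=3)
x = tf.constant([[0, 0, 1], [2, 2, 2]])  # repeated ids within each example
layer.build(x.shape)

base_vars, transform, sqr_norm_fn = embedding_layer_computation(layer, (x,))
with tf.GradientTape() as tape:
  summed_loss = tf.reduce_sum(transform(base_vars))
fast_sqr_norms = sqr_norm_fn(tape.gradient(summed_loss, base_vars))

# Brute force: one tape per example; densify the `IndexedSlices` gradient,
# square, and sum.
brute_sqr_norms = []
for i in range(x.shape[0]):
  with tf.GradientTape() as tape:
    loss_i = tf.reduce_sum(layer(x[i : i + 1]))
  grad_i = tape.gradient(loss_i, layer.trainable_variables[0])
  brute_sqr_norms.append(tf.reduce_sum(tf.square(tf.convert_to_tensor(grad_i))))

print(fast_sqr_norms.numpy())             # [15. 27.]
print(tf.stack(brute_sqr_norms).numpy())  # should match: [15. 27.]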