Generalize the registry function for the embedding layer for other models.
PiperOrigin-RevId: 509528743
This commit is contained in:
parent
410814ec39
commit
430f103354
1 changed files with 67 additions and 36 deletions
|
@ -136,8 +136,8 @@ def embedding_layer_computation(layer_instance, inputs):
|
||||||
`layer_instance(inputs)` returns a valid output.
|
`layer_instance(inputs)` returns a valid output.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tuple` `(base_vars, None, sqr_norm_fn)`, `base_vars` is the
|
A `tuple` `(base_vars, transform, sqr_norm_fn)`, `base_vars` is the
|
||||||
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
|
intermediate Tensor used in the chain-rule / "fast" clipping trick,
`transform` is a function that applies the embedding lookup to `base_vars`, and
|
||||||
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
||||||
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
||||||
where `tape` is a `tf.GradientTape` instance that records the dense
|
where `tape` is a `tf.GradientTape` instance that records the dense
|
||||||
|
@ -148,41 +148,72 @@ def embedding_layer_computation(layer_instance, inputs):
|
||||||
"""
|
"""
|
||||||
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
||||||
if layer_instance.sparse:
|
if layer_instance.sparse:
|
||||||
raise NotImplementedError("Sparse output vectors are not supported.")
|
raise NotImplementedError("Sparse output tensors are not supported.")
|
||||||
if len(inputs[0].shape) != 2:
|
if isinstance(inputs, tf.SparseTensor):
|
||||||
raise NotImplementedError("Only 2D embedding inputs are supported.")
|
raise NotImplementedError("Sparse input tensors are not supported.")
|
||||||
# The logic below is applied to properly handle repeated embedding indices.
|
|
||||||
# Specifically, sqr_grad_norms will contain the total counts of each embedding
|
|
||||||
# index (see how it is processed in the combine_pre_and_post_sqr_norms()
|
|
||||||
# function in clip_grads.py). An example is as follows:
|
|
||||||
#
|
|
||||||
# inputs =
|
|
||||||
# [[0 0 0 1 2 2],
|
|
||||||
# [0 2 2 2 1 1]]
|
|
||||||
#
|
|
||||||
# counts =
|
|
||||||
# [[3 1 2]
|
|
||||||
# [1 2 3]]
|
|
||||||
#
|
|
||||||
# input_counts =
|
|
||||||
# [[3 3 3 1 2 2],
|
|
||||||
# [1 3 3 3 2 2]]
|
|
||||||
#
|
|
||||||
base_vars = layer_instance(*inputs)
|
|
||||||
def sqr_norm_fn(base_vars_grads):
|
|
||||||
indices = tf.cast(*inputs, tf.int32)
|
|
||||||
if isinstance(indices, tf.SparseTensor):
|
|
||||||
indices = tf.sparse.to_dense(indices)
|
|
||||||
counts = tf.math.bincount(indices, axis=-1)
|
|
||||||
input_counts = tf.expand_dims(
|
|
||||||
tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
scaled_grads = input_counts * tf.square(base_vars_grads)
|
|
||||||
reduction_axes = tf.range(1, tf.rank(scaled_grads))
|
|
||||||
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
|
|
||||||
|
|
||||||
return base_vars, None, sqr_norm_fn
|
# Disable experimental features.
|
||||||
|
if hasattr(layer_instance, "_use_one_hot_matmul"):
|
||||||
|
if layer_instance._use_one_hot_matmul: # pylint: disable=protected-access
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The experimental embedding feature"
|
||||||
|
"'_use_one_hot_matmul' is not supported."
|
||||||
|
)
|
||||||
|
input_ids = tf.cast(*inputs, tf.int32)
|
||||||
|
base_vars = layer_instance.trainable_variables[0]
|
||||||
|
|
||||||
|
def lookup_inputs(embeddings):
|
||||||
|
return tf.nn.embedding_lookup(embeddings, input_ids)
|
||||||
|
|
||||||
|
def sqr_norm_fn(base_vars_grads):
|
||||||
|
# Get a column tensor (shape [N, 1]) holding the row index of each input id.
|
||||||
|
nrows = tf.shape(input_ids)[0]
|
||||||
|
if isinstance(input_ids, tf.RaggedTensor):
|
||||||
|
row_indices = tf.expand_dims(
|
||||||
|
input_ids.merge_dims(1, -1).value_rowids(), axis=-1
|
||||||
|
)
|
||||||
|
elif isinstance(input_ids, tf.Tensor):
|
||||||
|
ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
|
||||||
|
repeats = tf.repeat(ncols, nrows)
|
||||||
|
row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
|
||||||
|
)
|
||||||
|
# Sum-reduce the `IndexedSlices` that is the result of a `tape.gradient()`
|
||||||
|
# call. The sum is reduced by the repeated embedding indices and batch
|
||||||
|
# index. It is adapted from the logic in:
|
||||||
|
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
|
||||||
|
if not isinstance(base_vars_grads, tf.IndexedSlices):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Cannot parse embedding gradients of type: %s"
|
||||||
|
% base_vars_grads.__class__.__name__
|
||||||
|
)
|
||||||
|
slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
|
||||||
|
paired_indices = tf.concat(
|
||||||
|
[tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
|
||||||
|
x=paired_indices, axis=[0]
|
||||||
|
)
|
||||||
|
unique_batch_ids = unique_paired_indices[:, 0]
|
||||||
|
summed_gradients = tf.math.unsorted_segment_sum(
|
||||||
|
base_vars_grads.values,
|
||||||
|
new_index_positions,
|
||||||
|
tf.shape(unique_paired_indices)[0],
|
||||||
|
)
|
||||||
|
# Compute the squared gradient norms at the per-example level.
|
||||||
|
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
|
||||||
|
summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
|
||||||
|
return tf.sparse.segment_sum(
|
||||||
|
sqr_gradient_sum,
|
||||||
|
summed_data_range,
|
||||||
|
tf.sort(unique_batch_ids),
|
||||||
|
num_segments=nrows,
|
||||||
|
) # fill in empty inputs
|
||||||
|
|
||||||
|
return base_vars, lookup_inputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
Loading…
Reference in a new issue