Add compatibility for Einsum layers with dynamic shapes.

PiperOrigin-RevId: 628111219
This commit is contained in:
William Kong 2024-04-25 10:08:19 -07:00 committed by A. Unique TensorFlower
parent 3fa0a2d362
commit 3deaae30a1
2 changed files with 45 additions and 7 deletions

View file

@@ -28,11 +28,11 @@ EquationType = enum.Enum(
) )
def _is_batch_of_vectors(t: tf.Tensor) -> bool: def is_batch_of_vectors(t: tf.Tensor) -> bool:
"""Checks if an input is a batch of (effectively) 1D vectors.""" """Checks if an input is a batch of (effectively) 1D vectors."""
num_nontrivial_indices = 0 num_nontrivial_indices = 0
for s in t.shape[1:]: for s in t.shape[1:]:
if s > 1: if s is None or s > 1:
num_nontrivial_indices += 1 num_nontrivial_indices += 1
if num_nontrivial_indices > 1: if num_nontrivial_indices > 1:
return False return False
@@ -442,8 +442,8 @@ def compute_fast_einsum_squared_gradient_norm(
# NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do # NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
# a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`. # a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
if ( if (
_is_batch_of_vectors(input_tensor) is_batch_of_vectors(input_tensor)
and _is_batch_of_vectors(grad_tensor) and is_batch_of_vectors(grad_tensor)
and num_microbatches is None and num_microbatches is None
): ):
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1]) x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])

View file

@@ -29,20 +29,58 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
([1, 2], True), ([1, 2], True),
([2, 1], True), ([2, 1], True),
([2, 2], True), ([2, 2], True),
# 3D tensors # 3D tensors (batch of vectors)
([2, 1, 1], True), ([2, 1, 1], True),
([1, 2, 1], True), ([1, 2, 1], True),
([1, 1, 2], True), ([1, 1, 2], True),
([2, 2, 1], True), ([2, 2, 1], True),
([2, 1, 2], True), ([2, 1, 2], True),
# 3D tensors (batch of matrices)
([1, 2, 2], False), ([1, 2, 2], False),
([2, 2, 2], False), ([2, 2, 2], False),
] ]
) )
def test_is_batch_of_vectors(self, experiment_params): def test_is_batch_of_vectors_on_static_shapes(self, experiment_params):
shape, true_result = experiment_params shape, true_result = experiment_params
t = tf.zeros(shape) t = tf.zeros(shape)
computed_result = einsum_utils._is_batch_of_vectors(t) computed_result = einsum_utils.is_batch_of_vectors(t)
self.assertEqual(computed_result, true_result)
@parameterized.product(
experiment_params=[
# 1D tensors
([None], True),
# 2D tensors
([1, None], True),
([None, 1], True),
([None, None], True),
([2, None], True),
([None, 2], True),
# 3D tensors (batch of vectors)
([None, 1, 1], True),
([1, None, 1], True),
([1, 1, None], True),
([None, None, 1], True),
([None, 2, 1], True),
([2, None, 1], True),
([None, 1, None], True),
([None, 1, 2], True),
([2, 1, None], True),
([1, None, None], False),
# 3D tensors (batch of matrices)
([1, 2, None], False),
([1, None, 2], False),
([None, None, None], False),
([None, 2, None], False),
([None, None, 2], False),
([2, None, 2], False),
([2, 2, None], False),
]
)
def test_is_batch_of_vectors_on_dynamic_shapes(self, experiment_params):
shape, true_result = experiment_params
t = tf.keras.Input(shape=shape[1:], batch_size=shape[0])
computed_result = einsum_utils.is_batch_of_vectors(t)
self.assertEqual(computed_result, true_result) self.assertEqual(computed_result, true_result)
@parameterized.product( @parameterized.product(