Add compatability for Einsum layers with dynamic shapes.

PiperOrigin-RevId: 628111219
This commit is contained in:
William Kong 2024-04-25 10:08:19 -07:00 committed by A. Unique TensorFlower
parent 3fa0a2d362
commit 3deaae30a1
2 changed files with 45 additions and 7 deletions

View file

@ -28,11 +28,11 @@ EquationType = enum.Enum(
)
def _is_batch_of_vectors(t: tf.Tensor) -> bool:
def is_batch_of_vectors(t: tf.Tensor) -> bool:
"""Checks if an input is a batch of (effectively) 1D vectors."""
num_nontrivial_indices = 0
for s in t.shape[1:]:
if s > 1:
if s is None or s > 1:
num_nontrivial_indices += 1
if num_nontrivial_indices > 1:
return False
@ -442,8 +442,8 @@ def compute_fast_einsum_squared_gradient_norm(
# NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
# a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
if (
_is_batch_of_vectors(input_tensor)
and _is_batch_of_vectors(grad_tensor)
is_batch_of_vectors(input_tensor)
and is_batch_of_vectors(grad_tensor)
and num_microbatches is None
):
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])

View file

@ -29,20 +29,58 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
([1, 2], True),
([2, 1], True),
([2, 2], True),
# 3D tensors
# 3D tensors (batch of vectors)
([2, 1, 1], True),
([1, 2, 1], True),
([1, 1, 2], True),
([2, 2, 1], True),
([2, 1, 2], True),
# 3D tensors (batch of matrices)
([1, 2, 2], False),
([2, 2, 2], False),
]
)
def test_is_batch_of_vectors(self, experiment_params):
def test_is_batch_of_vectors_on_static_shapes(self, experiment_params):
shape, true_result = experiment_params
t = tf.zeros(shape)
computed_result = einsum_utils._is_batch_of_vectors(t)
computed_result = einsum_utils.is_batch_of_vectors(t)
self.assertEqual(computed_result, true_result)
@parameterized.product(
experiment_params=[
# 1D tensors
([None], True),
# 2D tensors
([1, None], True),
([None, 1], True),
([None, None], True),
([2, None], True),
([None, 2], True),
# 3D tensors (batch of vectors)
([None, 1, 1], True),
([1, None, 1], True),
([1, 1, None], True),
([None, None, 1], True),
([None, 2, 1], True),
([2, None, 1], True),
([None, 1, None], True),
([None, 1, 2], True),
([2, 1, None], True),
([1, None, None], False),
# 3D tensors (batch of matrices)
([1, 2, None], False),
([1, None, 2], False),
([None, None, None], False),
([None, 2, None], False),
([None, None, 2], False),
([2, None, 2], False),
([2, 2, None], False),
]
)
def test_is_batch_of_vectors_on_dynamic_shapes(self, experiment_params):
shape, true_result = experiment_params
t = tf.keras.Input(shape=shape[1:], batch_size=shape[0])
computed_result = einsum_utils.is_batch_of_vectors(t)
self.assertEqual(computed_result, true_result)
@parameterized.product(