Add compatibility for Einsum layers with dynamic shapes.
PiperOrigin-RevId: 628111219
This commit is contained in:
parent
3fa0a2d362
commit
3deaae30a1
2 changed files with 45 additions and 7 deletions
|
@ -28,11 +28,11 @@ EquationType = enum.Enum(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_batch_of_vectors(t: tf.Tensor) -> bool:
|
def is_batch_of_vectors(t: tf.Tensor) -> bool:
|
||||||
"""Checks if an input is a batch of (effectively) 1D vectors."""
|
"""Checks if an input is a batch of (effectively) 1D vectors."""
|
||||||
num_nontrivial_indices = 0
|
num_nontrivial_indices = 0
|
||||||
for s in t.shape[1:]:
|
for s in t.shape[1:]:
|
||||||
if s > 1:
|
if s is None or s > 1:
|
||||||
num_nontrivial_indices += 1
|
num_nontrivial_indices += 1
|
||||||
if num_nontrivial_indices > 1:
|
if num_nontrivial_indices > 1:
|
||||||
return False
|
return False
|
||||||
|
@ -442,8 +442,8 @@ def compute_fast_einsum_squared_gradient_norm(
|
||||||
# NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
|
# NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
|
||||||
# a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
|
# a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
|
||||||
if (
|
if (
|
||||||
_is_batch_of_vectors(input_tensor)
|
is_batch_of_vectors(input_tensor)
|
||||||
and _is_batch_of_vectors(grad_tensor)
|
and is_batch_of_vectors(grad_tensor)
|
||||||
and num_microbatches is None
|
and num_microbatches is None
|
||||||
):
|
):
|
||||||
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])
|
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])
|
||||||
|
|
|
@ -29,20 +29,58 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
([1, 2], True),
|
([1, 2], True),
|
||||||
([2, 1], True),
|
([2, 1], True),
|
||||||
([2, 2], True),
|
([2, 2], True),
|
||||||
# 3D tensors
|
# 3D tensors (batch of vectors)
|
||||||
([2, 1, 1], True),
|
([2, 1, 1], True),
|
||||||
([1, 2, 1], True),
|
([1, 2, 1], True),
|
||||||
([1, 1, 2], True),
|
([1, 1, 2], True),
|
||||||
([2, 2, 1], True),
|
([2, 2, 1], True),
|
||||||
([2, 1, 2], True),
|
([2, 1, 2], True),
|
||||||
|
# 3D tensors (batch of matrices)
|
||||||
([1, 2, 2], False),
|
([1, 2, 2], False),
|
||||||
([2, 2, 2], False),
|
([2, 2, 2], False),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_is_batch_of_vectors(self, experiment_params):
|
def test_is_batch_of_vectors_on_static_shapes(self, experiment_params):
|
||||||
shape, true_result = experiment_params
|
shape, true_result = experiment_params
|
||||||
t = tf.zeros(shape)
|
t = tf.zeros(shape)
|
||||||
computed_result = einsum_utils._is_batch_of_vectors(t)
|
computed_result = einsum_utils.is_batch_of_vectors(t)
|
||||||
|
self.assertEqual(computed_result, true_result)
|
||||||
|
|
||||||
|
@parameterized.product(
|
||||||
|
experiment_params=[
|
||||||
|
# 1D tensors
|
||||||
|
([None], True),
|
||||||
|
# 2D tensors
|
||||||
|
([1, None], True),
|
||||||
|
([None, 1], True),
|
||||||
|
([None, None], True),
|
||||||
|
([2, None], True),
|
||||||
|
([None, 2], True),
|
||||||
|
# 3D tensors (batch of vectors)
|
||||||
|
([None, 1, 1], True),
|
||||||
|
([1, None, 1], True),
|
||||||
|
([1, 1, None], True),
|
||||||
|
([None, None, 1], True),
|
||||||
|
([None, 2, 1], True),
|
||||||
|
([2, None, 1], True),
|
||||||
|
([None, 1, None], True),
|
||||||
|
([None, 1, 2], True),
|
||||||
|
([2, 1, None], True),
|
||||||
|
([1, None, None], False),
|
||||||
|
# 3D tensors (batch of matrices)
|
||||||
|
([1, 2, None], False),
|
||||||
|
([1, None, 2], False),
|
||||||
|
([None, None, None], False),
|
||||||
|
([None, 2, None], False),
|
||||||
|
([None, None, 2], False),
|
||||||
|
([2, None, 2], False),
|
||||||
|
([2, 2, None], False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_is_batch_of_vectors_on_dynamic_shapes(self, experiment_params):
|
||||||
|
shape, true_result = experiment_params
|
||||||
|
t = tf.keras.Input(shape=shape[1:], batch_size=shape[0])
|
||||||
|
computed_result = einsum_utils.is_batch_of_vectors(t)
|
||||||
self.assertEqual(computed_result, true_result)
|
self.assertEqual(computed_result, true_result)
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
|
|
Loading…
Reference in a new issue