Add the second set of EinsumDense utility functions for implementing fast gradient norm computation.
PiperOrigin-RevId: 568063831
This commit is contained in:
parent
1be6e026e7
commit
62a2d43d1c
3 changed files with 270 additions and 30 deletions
|
@ -19,6 +19,7 @@ py_test(
|
||||||
name = "einsum_utils_test",
|
name = "einsum_utils_test",
|
||||||
srcs = ["einsum_utils_test.py"],
|
srcs = ["einsum_utils_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
|
shard_count = 4,
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [":einsum_utils"],
|
deps = [":einsum_utils"],
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import itertools
|
import itertools
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -37,6 +38,36 @@ def _is_batch_of_vectors(t: tf.Tensor) -> bool:
|
||||||
return num_nontrivial_indices <= 1
|
return num_nontrivial_indices <= 1
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_einsum_equation(
|
||||||
|
maybe_ab: str,
|
||||||
|
maybe_bc: str,
|
||||||
|
maybe_ac: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Checks if three input strings form a valid einsum dense equation.
|
||||||
|
|
||||||
|
Given three substrings `maybe_ab`, `maybe_bc`, and `maybe_ac`, this function
|
||||||
|
checks if
|
||||||
|
```
|
||||||
|
maybe_ab + ',' + maybe_bc + '->' + maybe_ac
|
||||||
|
```
|
||||||
|
is an einsum equation of the form `ab,bc->ac`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maybe_ab: The proposed `ab` substring.
|
||||||
|
maybe_bc: The proposed `bc` substring.
|
||||||
|
maybe_ac: The proposed `ac` substring.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`True` if the three input strings form an einsum equation of the form
|
||||||
|
`ab,bc->ac` and `False` otherwise.
|
||||||
|
"""
|
||||||
|
a_substr = os.path.commonprefix([maybe_ab, maybe_ac])
|
||||||
|
a_len = len(a_substr)
|
||||||
|
b_substr = maybe_ab[a_len:]
|
||||||
|
c_substr = maybe_ac[a_len:]
|
||||||
|
return maybe_bc == b_substr + c_substr
|
||||||
|
|
||||||
|
|
||||||
def _parse_einsum_equation(
|
def _parse_einsum_equation(
|
||||||
equation: str,
|
equation: str,
|
||||||
) -> tuple[EquationType, tuple[str, str, str]]:
|
) -> tuple[EquationType, tuple[str, str, str]]:
|
||||||
|
@ -61,22 +92,33 @@ def _parse_einsum_equation(
|
||||||
maybe_match = re.fullmatch(regex_str, equation)
|
maybe_match = re.fullmatch(regex_str, equation)
|
||||||
return maybe_match.groups() if maybe_match is not None else None
|
return maybe_match.groups() if maybe_match is not None else None
|
||||||
|
|
||||||
groups1 = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)")
|
error_message = (
|
||||||
if groups1 is not None:
|
|
||||||
return EquationType.NO_ELLIPSES, groups1
|
|
||||||
groups2 = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)")
|
|
||||||
if groups2 is not None:
|
|
||||||
return EquationType.LEFT_ELLIPSES, groups2
|
|
||||||
groups3 = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.")
|
|
||||||
if groups3 is not None:
|
|
||||||
return EquationType.RIGHT_ELLIPSES, groups3
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid Einsum equation string "
|
"Invalid Einsum equation string "
|
||||||
+ equation
|
+ equation
|
||||||
+ " ."
|
+ " ."
|
||||||
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
|
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
|
||||||
"{ab...,bc->ac...}"
|
"{ab...,bc->ac...}"
|
||||||
)
|
)
|
||||||
|
case_pairs = [
|
||||||
|
# equation_type, regex_str
|
||||||
|
(EquationType.NO_ELLIPSES, r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)"),
|
||||||
|
(
|
||||||
|
EquationType.LEFT_ELLIPSES,
|
||||||
|
r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
EquationType.RIGHT_ELLIPSES,
|
||||||
|
r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for equation_type, regex_str in case_pairs:
|
||||||
|
groups = _try_match(regex_str)
|
||||||
|
if groups is not None:
|
||||||
|
if not _is_valid_einsum_equation(*groups):
|
||||||
|
raise ValueError(error_message)
|
||||||
|
return equation_type, groups
|
||||||
|
# No valid cases found. Raise an error.
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
|
|
||||||
def _reshape_einsum_inputs(
|
def _reshape_einsum_inputs(
|
||||||
|
@ -100,13 +142,17 @@ def _reshape_einsum_inputs(
|
||||||
and the number of output columns is the second dimension of the input. The
|
and the number of output columns is the second dimension of the input. The
|
||||||
product of the non-trivial dimensions of the output should be equal to
|
product of the non-trivial dimensions of the output should be equal to
|
||||||
the product of the dimensions of `input_tensor`.
|
the product of the dimensions of `input_tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `equation` is not a valid einsum equation in the context of
|
||||||
|
the `tf.keras.layers.EinsumDense` layer.
|
||||||
"""
|
"""
|
||||||
# Find the components `ab`, `bc`, and `ac` given that `equation` can only be
|
# Find the components `ab`, `bc`, and `ac` given that `equation` can only be
|
||||||
# one of the following mutually exclusive forms:
|
# one of the following mutually exclusive forms:
|
||||||
#
|
#
|
||||||
# (C1) ab,bc->ac,
|
# C1. ab,bc->ac,
|
||||||
# (C2) ...ab,bc->...ac
|
# C2. ...ab,bc->...ac
|
||||||
# (C3) ab...,bc->ac...
|
# C3. ab...,bc->ac...
|
||||||
#
|
#
|
||||||
# NOTE: `a`, `b`, and `c` are (possibly) also substrings.
|
# NOTE: `a`, `b`, and `c` are (possibly) also substrings.
|
||||||
|
|
||||||
|
@ -115,7 +161,7 @@ def _reshape_einsum_inputs(
|
||||||
input_len = len(input_shape)
|
input_len = len(input_shape)
|
||||||
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
|
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
|
||||||
if equation_type == EquationType.LEFT_ELLIPSES:
|
if equation_type == EquationType.LEFT_ELLIPSES:
|
||||||
# In case (C2), the `a` part of this component can be empty, so we have no
|
# In case C2, the `a` part of this component can be empty, so we have no
|
||||||
# choice but to compare the `c` part of `ac` with the `bc` component.
|
# choice but to compare the `c` part of `ac` with the `bc` component.
|
||||||
c_len = 0
|
c_len = 0
|
||||||
for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)):
|
for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)):
|
||||||
|
@ -135,7 +181,7 @@ def _reshape_einsum_inputs(
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
# Prepare `input_tensor` for reshaping and get the pivot index of the prepped
|
# Prepare `input_tensor` for reshaping and get the pivot index of the prepped
|
||||||
# tensor. Note that case (C3) requires a transpose to ensure that matrix
|
# tensor. Note that case C3 requires a transpose to ensure that matrix
|
||||||
# multiplication is performed by the caller.
|
# multiplication is performed by the caller.
|
||||||
if equation_type == EquationType.RIGHT_ELLIPSES:
|
if equation_type == EquationType.RIGHT_ELLIPSES:
|
||||||
ellipses_idx = len(ab_str)
|
ellipses_idx = len(ab_str)
|
||||||
|
@ -178,17 +224,124 @@ def _reshape_einsum_outputs(
|
||||||
A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
|
A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
|
||||||
product of the non-trivial dimensions of the output should be equal to
|
product of the non-trivial dimensions of the output should be equal to
|
||||||
the product of the non-trivial dimensions of `output_tensor`.
|
the product of the non-trivial dimensions of `output_tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `equation` is not a valid einsum equation in the context of
|
||||||
|
the `tf.keras.layers.EinsumDense` layer.
|
||||||
"""
|
"""
|
||||||
match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation)
|
# Get the raw components of the reversed equation.
|
||||||
if match is not None:
|
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
|
||||||
s1, s2, s3 = match.groups()
|
prefix = "..." if equation_type == EquationType.LEFT_ELLIPSES else ""
|
||||||
else:
|
suffix = "..." if equation_type == EquationType.RIGHT_ELLIPSES else ""
|
||||||
raise ValueError(
|
ellided_ab_str = prefix + ab_str + suffix
|
||||||
"Invalid Einsum equation string "
|
ellided_ac_str = prefix + ac_str + suffix
|
||||||
+ equation
|
# Swap the `b` and `c` components.
|
||||||
+ " ."
|
c_str = os.path.commonprefix([bc_str[::-1], ac_str[::-1]])[::-1]
|
||||||
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
|
b_len = len(bc_str) - len(c_str)
|
||||||
"{ab...,bc->ac...}"
|
b_str = bc_str[:b_len]
|
||||||
)
|
cb_str = c_str + b_str
|
||||||
reversed_equation = s3 + "," + s2[::-1] + "->" + s1
|
reversed_equation = ellided_ac_str + "," + cb_str + "->" + ellided_ab_str
|
||||||
return _reshape_einsum_inputs(output_tensor, reversed_equation)
|
return _reshape_einsum_inputs(output_tensor, reversed_equation)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_einsum_bias_adjoint_reduction_axes(
    equation: str,
    bias_axes: str,
    einsum_rank: int,
) -> list[int]:
  """Computes axes related to the per-example adjoint of the einsum bias op.

  To describe the output of this computation, first recall that for each
  example the `EinsumDense` layer performs the following transformation:
  ```
  F(W, bias | X) = Einsum(W, X) + Q(bias)
  ```
  where `W` is a tensor of trainable variables, `bias` is a tensor of rank
  `len(bias_axes)`, `X` is a batch of inputs, and `Q` is a linear broadcast
  operator that roughly corresponds to `Q(bias) ~= tf.broadcast_to(bias, S)` for
  `S := tf.shape(Einsum(W, X))`.

  It is straightforward to show that the per-example adjoint of `Q` is given by
  `Q'(Y) := tf.reduce_sum(Y, axes=R)` where `R` contains the broadcasting
  indices. This function returns `R` as an unordered list of `int`s.

  Assumptions:

    A1. `equation` is one of the following forms:
          C1. `ab,bc->ac`
          C2. `...ab,bc->...ac`
          C3. `ab...,bc->ac...`

    A2. The first character in the substring `a` (or `...a` in C2)
        in assumption A1 corresponds to the batch dimension.

    A3. The characters in `bias_axes` must be a subset of the non-batch
        dimension characters in the substring `ac` (or `...ac` in C2) in
        assumption A1.

    A4. `einsum_rank` is the length of the substring `ac` (or `...ac` in C2) in
        assumption A1. This includes the batch dimension.

  Examples:

    1.  equation = 'ab,bc->ac', bias_axes = 'c', einsum_rank = 2 -> []
    2.  equation = 'ab,bce->ace', bias_axes = 'ce', einsum_rank = 3 -> []
    3.  equation = 'ab,bce->ace', bias_axes = 'c', einsum_rank = 3 -> [2]
    4.  equation = 'ab,bce->ace', bias_axes = 'e', einsum_rank = 3 -> [1]
    5.  equation = 'ab,bced->aced', bias_axes = 'ced', einsum_rank = 4 -> []
    6.  equation = 'ab,bced->aced', bias_axes = 'ce', einsum_rank = 4 -> [3]
    7.  equation = 'ab,bced->aced', bias_axes = 'c', einsum_rank = 4 -> [2, 3]
    8.  equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 4
        -> [1, 3]
    9.  equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 10
        -> [1, 2, 3, 4, 5, 6, 7, 9]
    10. equation = 'ab...,bce->ace...', bias_axes = 'e', einsum_rank = 4
        -> [1, 3]

  Args:
    equation: The einsum equation `string`.
    bias_axes: A substring of the output part of `equation` specifying which
      axes a bias `tf.Tensor` is added to.
    einsum_rank: The rank of the tensor that the per-example adjoint operator is
      being applied to.

  Returns:
    A list of `int` containing the axes of the rank-`einsum_rank` tensor that
    are reduced over. Each `int` is at most `einsum_rank - 1` and excludes
    zero. The order of the entries is unspecified.

  Raises:
    ValueError: If `equation` is not a valid einsum equation in the context of
      the `tf.keras.layers.EinsumDense` layer, if the first output character
      of an equation without leading ellipses appears in `bias_axes`, or if
      `einsum_rank` is smaller than the number of explicit (non-ellipsis)
      output axes of `equation`.
  """
  equation_type, (_, _, ac_str) = _parse_einsum_equation(equation)
  has_left_ellipses = equation_type == EquationType.LEFT_ELLIPSES
  # Do not allow the bias axes to be the batch axis, since we want the adjoint
  # of the bias broadcast op to apply the same operation to all examples in a
  # batch.
  if not has_left_ellipses and ac_str[0] in bias_axes:
    raise ValueError(f"Bias axis '{bias_axes}' cannot also be the batch axis.")
  # Every explicit output character needs its own tensor axis. Without this
  # guard `zip_longest` below would pad the index side with `None`, and those
  # `None` entries would leak into the returned axes.
  if einsum_rank < len(ac_str):
    raise ValueError(
        f"einsum_rank={einsum_rank} is smaller than the number of explicit "
        f"output axes ({len(ac_str)}) in equation '{equation}'."
    )
  # If `equation` is of the form `...ab,bc->...ac`, i.e., case C2, we do a
  # right to left traversal so that the explicit characters line up with the
  # trailing axes; the other cases do a left to right traversal.
  input_indices = range(einsum_rank)
  if has_left_ellipses:
    traversal_zip = itertools.zip_longest(
        reversed(input_indices), reversed(ac_str)
    )
  else:
    traversal_zip = itertools.zip_longest(input_indices, ac_str)
  # Traverse the output part of `equation` and reduce over every non-batch axis
  # whose character is NOT a (still unmatched) bias character. Axes covered by
  # the `...` part have no character (`None`) and are therefore always reduced.
  reduction_axes = []
  bias_char_set = set(bias_axes)
  for idx, output_char in traversal_zip:
    # Exclude the batch dimension (idx == 0), since we want the per-example
    # adjoint.
    if idx == 0:
      continue
    if output_char in bias_char_set:
      # Consume the character so that it is matched at most once.
      bias_char_set.remove(output_char)
    else:
      reduction_axes.append(idx)
  return reduction_axes
|
||||||
|
|
|
@ -45,6 +45,21 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
computed_result = einsum_utils._is_batch_of_vectors(t)
|
computed_result = einsum_utils._is_batch_of_vectors(t)
|
||||||
self.assertEqual(computed_result, true_result)
|
self.assertEqual(computed_result, true_result)
|
||||||
|
|
||||||
|
@parameterized.product(
    experiment_params=[
        (('ab', 'bc', 'ac'), True),
        (('ab', 'a', 'b'), False),
        (('ab', 'ca', 'bc'), False),
        (('b', 'bc', 'c'), True),
        (('ab', 'bc', 'bc'), False),
        (('abc', 'cde', 'abde'), True),
    ]
)
def test_is_valid_einsum_equation(self, experiment_params):
  # Each case pairs the `(ab, bc, ac)` substrings with the expected verdict.
  substrings, expected = experiment_params
  maybe_ab, maybe_bc, maybe_ac = substrings
  actual = einsum_utils._is_valid_einsum_equation(
      maybe_ab, maybe_bc, maybe_ac
  )
  self.assertEqual(actual, expected)
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
experiment_params=[
|
experiment_params=[
|
||||||
(
|
(
|
||||||
|
@ -66,7 +81,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
)
|
)
|
||||||
def test_parse_einsum_equation(self, experiment_params):
|
def test_parse_einsum_equation(self, experiment_params):
|
||||||
equation, true_eqn_type, true_groups = experiment_params
|
equation, true_eqn_type, true_groups = experiment_params
|
||||||
(computed_eqn_type, computed_groups) = einsum_utils._parse_einsum_equation(
|
computed_eqn_type, computed_groups = einsum_utils._parse_einsum_equation(
|
||||||
equation
|
equation
|
||||||
)
|
)
|
||||||
self.assertEqual(computed_eqn_type, true_eqn_type)
|
self.assertEqual(computed_eqn_type, true_eqn_type)
|
||||||
|
@ -98,7 +113,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_reshape_einsum_inputs(self, experiment_params):
|
def test_reshape_einsum_inputs(self, experiment_params):
|
||||||
(equation, input_shape, true_permutations, true_parsed_shape) = (
|
equation, input_shape, true_permutations, true_parsed_shape = (
|
||||||
experiment_params
|
experiment_params
|
||||||
)
|
)
|
||||||
num_entries = int(np.prod(input_shape))
|
num_entries = int(np.prod(input_shape))
|
||||||
|
@ -141,7 +156,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_reshape_einsum_outputs(self, experiment_params):
|
def test_reshape_einsum_outputs(self, experiment_params):
|
||||||
(equation, output_shape, true_permutations, true_parsed_shape) = (
|
equation, output_shape, true_permutations, true_parsed_shape = (
|
||||||
experiment_params
|
experiment_params
|
||||||
)
|
)
|
||||||
num_entries = int(np.prod(output_shape))
|
num_entries = int(np.prod(output_shape))
|
||||||
|
@ -158,6 +173,77 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape)
|
true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape)
|
||||||
self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor)
|
self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor)
|
||||||
|
|
||||||
|
@parameterized.product(
    experiment_params=[
        # einsum_utils.EquationType.NO_ELLIPSES
        ('ab,bc->ac', 'c', 2, []),
        ('ab,bce->ace', 'ce', 3, []),
        ('ab,bce->ace', 'ec', 3, []),
        ('ab,bce->ace', 'c', 3, [2]),
        ('ab,bce->ace', 'e', 3, [1]),
        ('ab,bced->aced', 'ced', 4, []),
        ('ab,bced->aced', 'edc', 4, []),
        ('ab,bced->aced', 'ce', 4, [3]),
        ('ab,bced->aced', 'ec', 4, [3]),
        ('ab,bced->aced', 'cd', 4, [2]),
        ('ab,bced->aced', 'ed', 4, [1]),
        ('ab,bced->aced', 'c', 4, [2, 3]),
        ('ab,bced->aced', 'e', 4, [1, 3]),
        ('ab,bced->aced', 'd', 4, [1, 2]),
        # einsum_utils.EquationType.LEFT_ELLIPSES
        ('...b,bc->...c', 'c', 2, []),
        ('...b,bce->...ce', 'c', 3, [2]),
        ('...b,bce->...ce', 'e', 3, [1]),
        ('...ab,bc->...ac', 'c', 3, [1]),
        ('...ab,bce->...ace', 'ac', 4, [3]),
        ('...ab,bce->...ace', 'ae', 4, [2]),
        ('...ab,bce->...ace', 'ce', 4, [1]),
        ('...ab,bce->...ace', 'ec', 4, [1]),
        ('...ab,bce->...ace', 'a', 4, [2, 3]),
        ('...ab,bce->...ace', 'c', 4, [1, 3]),
        ('...ab,bce->...ace', 'e', 4, [1, 2]),
        ('...ab,bce->...ace', 'c', 5, [1, 2, 4]),
        ('...ab,bce->...ace', 'c', 10, [1, 2, 3, 4, 5, 6, 7, 9]),
        # einsum_utils.EquationType.RIGHT_ELLIPSES
        ('ab...,bc->ac...', 'c', 3, [2]),
        ('ab...,bce->ace...', 'ce', 4, [3]),
        ('ab...,bce->ace...', 'ec', 4, [3]),
        ('ab...,bce->ace...', 'c', 4, [2, 3]),
        ('ab...,bce->ace...', 'e', 4, [1, 3]),
    ]
)
def test_get_einsum_bias_adjoint_reduction_axes(self, experiment_params):
  # Each case is `(equation, bias_axes, einsum_rank, expected reduction axes)`.
  equation, bias_axes, einsum_rank, true_reduction_axes = experiment_params
  computed_reduction_axes = (
      einsum_utils._get_einsum_bias_adjoint_reduction_axes(
          equation, bias_axes, einsum_rank
      )
  )
  # The function under test returns its axes in no particular order, so compare
  # sorted copies. `sorted` is used instead of `list.sort` so that the shared
  # lists in the case table above are never mutated in place.
  self.assertAllEqual(
      sorted(computed_reduction_axes), sorted(true_reduction_axes)
  )
|
||||||
|
|
||||||
|
@parameterized.product(
    experiment_params=[
        # einsum_utils.EquationType.NO_ELLIPSES
        ('ab,bc->ac', 'a', 2),
        # einsum_utils.EquationType.RIGHT_ELLIPSES
        ('ab...,bc->ac...', 'a', 3),
        ('ab...,bc->ac...', 'a', 4),
        ('ab...,bcde->acde...', 'acd', 4),
    ]
)
def test_bias_axis_eq_batch_axis_throws_error(self, experiment_params):
  # Each case is `(equation, bias_axes, einsum_rank)` where `bias_axes`
  # contains the first character of the output part of `equation`.
  equation, bias_axes, einsum_rank = experiment_params
  expected_message = f"Bias axis '{bias_axes}' cannot also be the batch axis."
  with self.assertRaises(ValueError) as raised:
    einsum_utils._get_einsum_bias_adjoint_reduction_axes(
        equation, bias_axes, einsum_rank
    )
  # The error text must match exactly, not merely contain the message.
  self.assertEqual(expected_message, str(raised.exception))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue