Add the second set of EinsumDense utility functions for implementing fast gradient norm computation.

PiperOrigin-RevId: 568063831
This commit is contained in:
A. Unique TensorFlower 2023-09-24 16:16:11 -07:00
parent 1be6e026e7
commit 62a2d43d1c
3 changed files with 270 additions and 30 deletions

View file

@ -19,6 +19,7 @@ py_test(
name = "einsum_utils_test", name = "einsum_utils_test",
srcs = ["einsum_utils_test.py"], srcs = ["einsum_utils_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 4,
srcs_version = "PY3", srcs_version = "PY3",
deps = [":einsum_utils"], deps = [":einsum_utils"],
) )

View file

@ -15,6 +15,7 @@
import enum import enum
import itertools import itertools
import os
import re import re
import numpy as np import numpy as np
@ -37,6 +38,36 @@ def _is_batch_of_vectors(t: tf.Tensor) -> bool:
return num_nontrivial_indices <= 1 return num_nontrivial_indices <= 1
def _is_valid_einsum_equation(
maybe_ab: str,
maybe_bc: str,
maybe_ac: str,
) -> bool:
"""Checks if three input strings form a valid einsum dense equation.
Given three substrings `maybe_ab`, `maybe_bc`, and `maybe_ac`, this function
checks if
```
maybe_ab + ',' + maybe_bc + '->' + maybe_ac
```
is an einsum equation of the form `ab,bc->ac`.
Args:
maybe_ab: The proposed `ab` substring.
maybe_bc: The proposed `bc` substring.
maybe_ac: The proposed `ac` substring.
Returns:
`True` if the three input strings form an einsum equation of the form
`ab,bc->ac` and `False` otherwise.
"""
a_substr = os.path.commonprefix([maybe_ab, maybe_ac])
a_len = len(a_substr)
b_substr = maybe_ab[a_len:]
c_substr = maybe_ac[a_len:]
return maybe_bc == b_substr + c_substr
def _parse_einsum_equation( def _parse_einsum_equation(
equation: str, equation: str,
) -> tuple[EquationType, tuple[str, str, str]]: ) -> tuple[EquationType, tuple[str, str, str]]:
@ -61,22 +92,33 @@ def _parse_einsum_equation(
maybe_match = re.fullmatch(regex_str, equation) maybe_match = re.fullmatch(regex_str, equation)
return maybe_match.groups() if maybe_match is not None else None return maybe_match.groups() if maybe_match is not None else None
groups1 = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)") error_message = (
if groups1 is not None:
return EquationType.NO_ELLIPSES, groups1
groups2 = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)")
if groups2 is not None:
return EquationType.LEFT_ELLIPSES, groups2
groups3 = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.")
if groups3 is not None:
return EquationType.RIGHT_ELLIPSES, groups3
raise ValueError(
"Invalid Einsum equation string " "Invalid Einsum equation string "
+ equation + equation
+ " ." + " ."
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, " "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
"{ab...,bc->ac...}" "{ab...,bc->ac...}"
) )
case_pairs = [
# equation_type, regex_str
(EquationType.NO_ELLIPSES, r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)"),
(
EquationType.LEFT_ELLIPSES,
r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)",
),
(
EquationType.RIGHT_ELLIPSES,
r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.",
),
]
for equation_type, regex_str in case_pairs:
groups = _try_match(regex_str)
if groups is not None:
if not _is_valid_einsum_equation(*groups):
raise ValueError(error_message)
return equation_type, groups
# No valid cases found. Raise an error.
raise ValueError(error_message)
def _reshape_einsum_inputs( def _reshape_einsum_inputs(
@ -100,13 +142,17 @@ def _reshape_einsum_inputs(
and the number of output columns is the second dimension of the input. The and the number of output columns is the second dimension of the input. The
product of the non-trivial dimensions of the output should be equal to product of the non-trivial dimensions of the output should be equal to
the product of the dimensions of `input_tensor`. the product of the dimensions of `input_tensor`.
Raises:
ValueError: If `equation` is not a valid einsum equation in the context of
the `tf.keras.layers.EinsumDense` layer.
""" """
# Find the components `ab`, `bc`, and `ac` given that `equation` can only be # Find the components `ab`, `bc`, and `ac` given that `equation` can only be
# one of the following mutually exclusive forms: # one of the following mutually exclusive forms:
# #
# (C1) ab,bc->ac, # C1. ab,bc->ac,
# (C2) ...ab,bc->...ac # C2. ...ab,bc->...ac
# (C3) ab...,bc->ac... # C3. ab...,bc->ac...
# #
# NOTE: `a`, `b`, and `c` are (possibly) also substrings. # NOTE: `a`, `b`, and `c` are (possibly) also substrings.
@ -115,7 +161,7 @@ def _reshape_einsum_inputs(
input_len = len(input_shape) input_len = len(input_shape)
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation) equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
if equation_type == EquationType.LEFT_ELLIPSES: if equation_type == EquationType.LEFT_ELLIPSES:
# In case (C2), the `a` part of this component can be empty, so we have no # In case C2, the `a` part of this component can be empty, so we have no
# choice but to compare the `c` part of `ac` with the `bc` component. # choice but to compare the `c` part of `ac` with the `bc` component.
c_len = 0 c_len = 0
for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)): for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)):
@ -135,7 +181,7 @@ def _reshape_einsum_inputs(
else: else:
break break
# Prepare `input_tensor` for reshaping and get the pivot index of the prepped # Prepare `input_tensor` for reshaping and get the pivot index of the prepped
# tensor. Note that case (C3) requires a transpose to ensure that matrix # tensor. Note that case C3 requires a transpose to ensure that matrix
# multiplication is performed by the caller. # multiplication is performed by the caller.
if equation_type == EquationType.RIGHT_ELLIPSES: if equation_type == EquationType.RIGHT_ELLIPSES:
ellipses_idx = len(ab_str) ellipses_idx = len(ab_str)
@ -178,17 +224,124 @@ def _reshape_einsum_outputs(
A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
product of the non-trivial dimensions of the output should be equal to product of the non-trivial dimensions of the output should be equal to
the product of the non-trivial dimensions of `output_tensor`. the product of the non-trivial dimensions of `output_tensor`.
Raises:
ValueError: If `equation` is not a valid einsum equation in the context of
the `tf.keras.layers.EinsumDense` layer.
""" """
match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation) # Get the raw components of the reversed equation.
if match is not None: equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
s1, s2, s3 = match.groups() prefix = "..." if equation_type == EquationType.LEFT_ELLIPSES else ""
else: suffix = "..." if equation_type == EquationType.RIGHT_ELLIPSES else ""
raise ValueError( ellided_ab_str = prefix + ab_str + suffix
"Invalid Einsum equation string " ellided_ac_str = prefix + ac_str + suffix
+ equation # Swap the `b` and `c` components.
+ " ." c_str = os.path.commonprefix([bc_str[::-1], ac_str[::-1]])[::-1]
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, " b_len = len(bc_str) - len(c_str)
"{ab...,bc->ac...}" b_str = bc_str[:b_len]
) cb_str = c_str + b_str
reversed_equation = s3 + "," + s2[::-1] + "->" + s1 reversed_equation = ellided_ac_str + "," + cb_str + "->" + ellided_ab_str
return _reshape_einsum_inputs(output_tensor, reversed_equation) return _reshape_einsum_inputs(output_tensor, reversed_equation)
def _get_einsum_bias_adjoint_reduction_axes(
    equation: str,
    bias_axes: str,
    einsum_rank: int,
) -> list[int]:
  """Computes axes related to the per-example adjoint of the einsum bias op.

  To describe the output of this computation, first recall that for each
  example the `EinsumDense` layer performs the following transformation:

  ```
  F(W, bias | X) = Einsum(W, X) + Q(bias)
  ```

  where `W` is a tensor of trainable variables, `bias` is a tensor of rank
  `len(bias_axes)`, `X` is a batch of inputs, and `Q` is a linear broadcast
  operator that roughly corresponds to `Q(bias) ~= tf.broadcast_to(bias, S)` for
  `S := tf.shape(Einsum(W, X))`.

  It is straightforward to show that the per-example adjoint of `Q` is given by
  `Q'(Y) := tf.reduce_sum(Y, axis=R)` where `R` contains the broadcasting
  indices. This function returns `R` as an unordered list of `int`s.

  Assumptions:
    A1. `equation` is one of the following forms:
          C1. `ab,bc->ac`
          C2. `...ab,bc->...ac`
          C3. `ab...,bc->ac...`
    A2. The first character in the substring `a` (or `...a` in C2)
        in assumption A1 corresponds to the batch dimension.
    A3. The characters in `bias_axes` must be a subset of the non-batch
        dimension characters in the substring `ac` (or `...ac` in C2) in
        assumption A1.
    A4. `einsum_rank` is the length of the substring `ac` (or `...ac` in C2) in
        assumption A1. This includes the batch dimension.

  Examples:
    1. equation = 'ab,bc->ac', bias_axes = 'c', einsum_rank = 2 -> []
    2. equation = 'ab,bce->ace', bias_axes = 'ce', einsum_rank = 3 -> []
    3. equation = 'ab,bce->ace', bias_axes = 'c', einsum_rank = 3 -> [2]
    4. equation = 'ab,bce->ace', bias_axes = 'e', einsum_rank = 3 -> [1]
    5. equation = 'ab,bced->aced', bias_axes = 'ced', einsum_rank = 4 -> []
    6. equation = 'ab,bced->aced', bias_axes = 'ce', einsum_rank = 4 -> [3]
    7. equation = 'ab,bced->aced', bias_axes = 'c', einsum_rank = 4 -> [2, 3]
    8. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 4
       -> [1, 3]
    9. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 10
       -> [1, 2, 3, 4, 5, 6, 7, 9]
    10. equation = 'ab...,bce->ace...', bias_axes = 'e', einsum_rank = 4
       -> [1, 3]

  Args:
    equation: The einsum equation `string`.
    bias_axes: A substring of the output part of `equation` specifying which
      axes a bias `tf.Tensor` is added to.
    einsum_rank: The rank of the tensor that the per-example adjoint operator is
      being applied to.

  Returns:
    A list of `int` containing the axes, of a tensor of rank `einsum_rank`,
    that the adjoint sums over. Each `int` is at most `einsum_rank - 1` and is
    never zero (the batch axis).

  Raises:
    ValueError: If `equation` is not a valid einsum equation in the context of
      the `tf.keras.layers.EinsumDense` layer, if the leading (batch) character
      of the output appears in `bias_axes` for an equation without left
      ellipses, or if `einsum_rank` is smaller than the number of characters in
      the output part of `equation`.
  """
  equation_type, (_, _, ac_str) = _parse_einsum_equation(equation)
  is_left_ellipses = equation_type == EquationType.LEFT_ELLIPSES

  # Do not allow the bias axes to be the batch axis, since we want the adjoint
  # of the bias broadcast op to apply the same operation to all examples in a
  # batch.
  if not is_left_ellipses and ac_str[0] in bias_axes:
    raise ValueError(f"Bias axis '{bias_axes}' cannot also be the batch axis.")

  # Every character of the output must map to a real axis. Without this check
  # the traversal below would pair the surplus characters with a `None` index
  # and silently produce a meaningless result.
  if einsum_rank < len(ac_str):
    raise ValueError(
        f"`einsum_rank` ({einsum_rank}) cannot be smaller than the number of "
        f"output axes ({len(ac_str)}) named in the equation '{equation}'."
    )

  # If `equation` is of the form `...ab,bc->...ac`, i.e., case C2, the named
  # axes are right-aligned, so we do a right to left traversal; the other cases
  # do a left to right traversal. Axes covered by `...` are paired with `None`.
  input_indices = range(einsum_rank)
  traversal_zip = (
      itertools.zip_longest(reversed(input_indices), reversed(ac_str))
      if is_left_ellipses
      else itertools.zip_longest(input_indices, ac_str)
  )

  # Traverse the output part of `equation` and keep every non-batch index whose
  # character is NOT a bias axis. Indices in the `...` part have no character
  # (`None`) and are therefore always kept.
  unmatched_bias_chars = set(bias_axes)
  reduction_axes = []
  for idx, output_char in traversal_zip:
    # Exclude the batch dimension (idx == 0), since we want the per-example
    # adjoint.
    if idx == 0:
      continue
    if output_char in unmatched_bias_chars:
      # Each bias character accounts for a single output axis.
      unmatched_bias_chars.remove(output_char)
    else:
      reduction_axes.append(idx)
  return reduction_axes

View file

@ -45,6 +45,21 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
computed_result = einsum_utils._is_batch_of_vectors(t) computed_result = einsum_utils._is_batch_of_vectors(t)
self.assertEqual(computed_result, true_result) self.assertEqual(computed_result, true_result)
@parameterized.product(
    experiment_params=[
        (('ab', 'bc', 'ac'), True),
        (('ab', 'a', 'b'), False),
        (('ab', 'ca', 'bc'), False),
        (('b', 'bc', 'c'), True),
        (('ab', 'bc', 'bc'), False),
        (('abc', 'cde', 'abde'), True),
    ]
)
def test_is_valid_einsum_equation(self, experiment_params):
  """Checks the validity verdict for hand-labelled `(ab, bc, ac)` triples."""
  (maybe_ab, maybe_bc, maybe_ac), expected_validity = experiment_params
  actual_validity = einsum_utils._is_valid_einsum_equation(
      maybe_ab, maybe_bc, maybe_ac
  )
  self.assertEqual(actual_validity, expected_validity)
@parameterized.product( @parameterized.product(
experiment_params=[ experiment_params=[
( (
@ -66,7 +81,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_parse_einsum_equation(self, experiment_params): def test_parse_einsum_equation(self, experiment_params):
equation, true_eqn_type, true_groups = experiment_params equation, true_eqn_type, true_groups = experiment_params
(computed_eqn_type, computed_groups) = einsum_utils._parse_einsum_equation( computed_eqn_type, computed_groups = einsum_utils._parse_einsum_equation(
equation equation
) )
self.assertEqual(computed_eqn_type, true_eqn_type) self.assertEqual(computed_eqn_type, true_eqn_type)
@ -98,7 +113,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
] ]
) )
def test_reshape_einsum_inputs(self, experiment_params): def test_reshape_einsum_inputs(self, experiment_params):
(equation, input_shape, true_permutations, true_parsed_shape) = ( equation, input_shape, true_permutations, true_parsed_shape = (
experiment_params experiment_params
) )
num_entries = int(np.prod(input_shape)) num_entries = int(np.prod(input_shape))
@ -141,7 +156,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
] ]
) )
def test_reshape_einsum_outputs(self, experiment_params): def test_reshape_einsum_outputs(self, experiment_params):
(equation, output_shape, true_permutations, true_parsed_shape) = ( equation, output_shape, true_permutations, true_parsed_shape = (
experiment_params experiment_params
) )
num_entries = int(np.prod(output_shape)) num_entries = int(np.prod(output_shape))
@ -158,6 +173,77 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape) true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape)
self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor) self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor)
@parameterized.product(
    experiment_params=[
        # einsum_utils.EquationType.NO_ELLIPSES
        ('ab,bc->ac', 'c', 2, []),
        ('ab,bce->ace', 'ce', 3, []),
        ('ab,bce->ace', 'ec', 3, []),
        ('ab,bce->ace', 'c', 3, [2]),
        ('ab,bce->ace', 'e', 3, [1]),
        ('ab,bced->aced', 'ced', 4, []),
        ('ab,bced->aced', 'edc', 4, []),
        ('ab,bced->aced', 'ce', 4, [3]),
        ('ab,bced->aced', 'ec', 4, [3]),
        ('ab,bced->aced', 'cd', 4, [2]),
        ('ab,bced->aced', 'ed', 4, [1]),
        ('ab,bced->aced', 'c', 4, [2, 3]),
        ('ab,bced->aced', 'e', 4, [1, 3]),
        ('ab,bced->aced', 'd', 4, [1, 2]),
        # einsum_utils.EquationType.LEFT_ELLIPSES
        ('...b,bc->...c', 'c', 2, []),
        ('...b,bce->...ce', 'c', 3, [2]),
        ('...b,bce->...ce', 'e', 3, [1]),
        ('...ab,bc->...ac', 'c', 3, [1]),
        ('...ab,bce->...ace', 'ac', 4, [3]),
        ('...ab,bce->...ace', 'ae', 4, [2]),
        ('...ab,bce->...ace', 'ce', 4, [1]),
        ('...ab,bce->...ace', 'ec', 4, [1]),
        ('...ab,bce->...ace', 'a', 4, [2, 3]),
        ('...ab,bce->...ace', 'c', 4, [1, 3]),
        ('...ab,bce->...ace', 'e', 4, [1, 2]),
        ('...ab,bce->...ace', 'c', 5, [1, 2, 4]),
        ('...ab,bce->...ace', 'c', 10, [1, 2, 3, 4, 5, 6, 7, 9]),
        # einsum_utils.EquationType.RIGHT_ELLIPSES
        ('ab...,bc->ac...', 'c', 3, [2]),
        ('ab...,bce->ace...', 'ce', 4, [3]),
        ('ab...,bce->ace...', 'ec', 4, [3]),
        ('ab...,bce->ace...', 'c', 4, [2, 3]),
        ('ab...,bce->ace...', 'e', 4, [1, 3]),
    ]
)
def test_get_einsum_bias_adjoint_reduction_axes(self, experiment_params):
  """Checks the reduction axes for each supported equation form."""
  equation, bias_axes, einsum_rank, expected_axes = experiment_params
  actual_axes = einsum_utils._get_einsum_bias_adjoint_reduction_axes(
      equation, bias_axes, einsum_rank
  )
  # The helper makes no promise about ordering, so sort both sides before
  # comparing them.
  for axes in (actual_axes, expected_axes):
    axes.sort()
  self.assertAllEqual(actual_axes, expected_axes)
@parameterized.product(
    experiment_params=[
        # einsum_utils.EquationType.NO_ELLIPSES
        ('ab,bc->ac', 'a', 2),
        # einsum_utils.EquationType.RIGHT_ELLIPSES
        ('ab...,bc->ac...', 'a', 3),
        ('ab...,bc->ac...', 'a', 4),
        ('ab...,bcde->acde...', 'acd', 4),
    ]
)
def test_bias_axis_eq_batch_axis_throws_error(self, experiment_params):
  """Checks that a bias on the batch axis is rejected with a clear message."""
  equation, bias_axes, einsum_rank = experiment_params
  expected_message = f"Bias axis '{bias_axes}' cannot also be the batch axis."
  with self.assertRaises(ValueError) as raised:
    einsum_utils._get_einsum_bias_adjoint_reduction_axes(
        equation, bias_axes, einsum_rank
    )
  # Inspect the captured exception once the context manager has exited.
  self.assertEqual(expected_message, str(raised.exception))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()