forked from 626_privacy/tensorflow_privacy
Add the second set of EinsumDense utility functions for implementing fast gradient norm computation.
PiperOrigin-RevId: 568063831
This commit is contained in:
parent
1be6e026e7
commit
62a2d43d1c
3 changed files with 270 additions and 30 deletions
|
@ -19,6 +19,7 @@ py_test(
|
|||
name = "einsum_utils_test",
|
||||
srcs = ["einsum_utils_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
srcs_version = "PY3",
|
||||
deps = [":einsum_utils"],
|
||||
)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import enum
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
@ -37,6 +38,36 @@ def _is_batch_of_vectors(t: tf.Tensor) -> bool:
|
|||
return num_nontrivial_indices <= 1
|
||||
|
||||
|
||||
def _is_valid_einsum_equation(
|
||||
maybe_ab: str,
|
||||
maybe_bc: str,
|
||||
maybe_ac: str,
|
||||
) -> bool:
|
||||
"""Checks if three input strings form a valid einsum dense equation.
|
||||
|
||||
Given three substrings `maybe_ab`, `maybe_bc`, and `maybe_ac`, this function
|
||||
checks if
|
||||
```
|
||||
maybe_ab + ',' + maybe_bc + '->' + maybe_ac
|
||||
```
|
||||
is an einsum equation of the form `ab,bc->ac`.
|
||||
|
||||
Args:
|
||||
maybe_ab: The proposed `ab` substring.
|
||||
maybe_bc: The proposed `bc` substring.
|
||||
maybe_ac: The proposed `ac` substring.
|
||||
|
||||
Returns:
|
||||
`True` if the three input strings form an einsum equation of the form
|
||||
`ab,bc->ac` and `False` otherwise.
|
||||
"""
|
||||
a_substr = os.path.commonprefix([maybe_ab, maybe_ac])
|
||||
a_len = len(a_substr)
|
||||
b_substr = maybe_ab[a_len:]
|
||||
c_substr = maybe_ac[a_len:]
|
||||
return maybe_bc == b_substr + c_substr
|
||||
|
||||
|
||||
def _parse_einsum_equation(
|
||||
equation: str,
|
||||
) -> tuple[EquationType, tuple[str, str, str]]:
|
||||
|
@ -61,22 +92,33 @@ def _parse_einsum_equation(
|
|||
maybe_match = re.fullmatch(regex_str, equation)
|
||||
return maybe_match.groups() if maybe_match is not None else None
|
||||
|
||||
groups1 = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)")
|
||||
if groups1 is not None:
|
||||
return EquationType.NO_ELLIPSES, groups1
|
||||
groups2 = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)")
|
||||
if groups2 is not None:
|
||||
return EquationType.LEFT_ELLIPSES, groups2
|
||||
groups3 = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.")
|
||||
if groups3 is not None:
|
||||
return EquationType.RIGHT_ELLIPSES, groups3
|
||||
raise ValueError(
|
||||
error_message = (
|
||||
"Invalid Einsum equation string "
|
||||
+ equation
|
||||
+ " ."
|
||||
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
|
||||
"{ab...,bc->ac...}"
|
||||
)
|
||||
case_pairs = [
|
||||
# equation_type, regex_str
|
||||
(EquationType.NO_ELLIPSES, r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)"),
|
||||
(
|
||||
EquationType.LEFT_ELLIPSES,
|
||||
r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)",
|
||||
),
|
||||
(
|
||||
EquationType.RIGHT_ELLIPSES,
|
||||
r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.",
|
||||
),
|
||||
]
|
||||
for equation_type, regex_str in case_pairs:
|
||||
groups = _try_match(regex_str)
|
||||
if groups is not None:
|
||||
if not _is_valid_einsum_equation(*groups):
|
||||
raise ValueError(error_message)
|
||||
return equation_type, groups
|
||||
# No valid cases found. Raise an error.
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def _reshape_einsum_inputs(
|
||||
|
@ -100,13 +142,17 @@ def _reshape_einsum_inputs(
|
|||
and the number of output columns is the second dimension of the input. The
|
||||
product of the non-trivial dimensions of the output should be equal to
|
||||
the product of the dimensions of `input_tensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `equation` is not a valid einsum equation in the context of
|
||||
the `tf.keras.layers.EinsumDense` layer.
|
||||
"""
|
||||
# Find the components `ab`, `bc`, and `ac` given that `equation` can only be
|
||||
# one of the following mutually exclusive forms:
|
||||
#
|
||||
# (C1) ab,bc->ac,
|
||||
# (C2) ...ab,bc->...ac
|
||||
# (C3) ab...,bc->ac...
|
||||
# C1. ab,bc->ac,
|
||||
# C2. ...ab,bc->...ac
|
||||
# C3. ab...,bc->ac...
|
||||
#
|
||||
# NOTE: `a`, `b`, and `c` are (possibly) also substrings.
|
||||
|
||||
|
@ -115,7 +161,7 @@ def _reshape_einsum_inputs(
|
|||
input_len = len(input_shape)
|
||||
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
|
||||
if equation_type == EquationType.LEFT_ELLIPSES:
|
||||
# In case (C2), the `a` part of this component can be empty, so we have no
|
||||
# In case C2, the `a` part of this component can be empty, so we have no
|
||||
# choice but to compare the `c` part of `ac` with the `bc` component.
|
||||
c_len = 0
|
||||
for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)):
|
||||
|
@ -135,7 +181,7 @@ def _reshape_einsum_inputs(
|
|||
else:
|
||||
break
|
||||
# Prepare `input_tensor` for reshaping and get the pivot index of the prepped
|
||||
# tensor. Note that case (C3) requires a transpose to ensure that matrix
|
||||
# tensor. Note that case C3 requires a transpose to ensure that matrix
|
||||
# multiplication is performed by the caller.
|
||||
if equation_type == EquationType.RIGHT_ELLIPSES:
|
||||
ellipses_idx = len(ab_str)
|
||||
|
@ -178,17 +224,124 @@ def _reshape_einsum_outputs(
|
|||
A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
|
||||
product of the non-trivial dimensions of the output should be equal to
|
||||
the product of the non-trivial dimensions of `output_tensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `equation` is not a valid einsum equation in the context of
|
||||
the `tf.keras.layers.EinsumDense` layer.
|
||||
"""
|
||||
match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation)
|
||||
if match is not None:
|
||||
s1, s2, s3 = match.groups()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid Einsum equation string "
|
||||
+ equation
|
||||
+ " ."
|
||||
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
|
||||
"{ab...,bc->ac...}"
|
||||
)
|
||||
reversed_equation = s3 + "," + s2[::-1] + "->" + s1
|
||||
# Get the raw components of the reversed equation.
|
||||
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
|
||||
prefix = "..." if equation_type == EquationType.LEFT_ELLIPSES else ""
|
||||
suffix = "..." if equation_type == EquationType.RIGHT_ELLIPSES else ""
|
||||
ellided_ab_str = prefix + ab_str + suffix
|
||||
ellided_ac_str = prefix + ac_str + suffix
|
||||
# Swap the `b` and `c` components.
|
||||
c_str = os.path.commonprefix([bc_str[::-1], ac_str[::-1]])[::-1]
|
||||
b_len = len(bc_str) - len(c_str)
|
||||
b_str = bc_str[:b_len]
|
||||
cb_str = c_str + b_str
|
||||
reversed_equation = ellided_ac_str + "," + cb_str + "->" + ellided_ab_str
|
||||
return _reshape_einsum_inputs(output_tensor, reversed_equation)
|
||||
|
||||
|
||||
def _get_einsum_bias_adjoint_reduction_axes(
    equation: str,
    bias_axes: str,
    einsum_rank: int,
) -> list[int]:
  """Computes axes related to the per-example adjoint of the einsum bias op.

  To describe the output of this computation, first recall that for each
  example the `EinsumDense` layer performs the following transformation:
  ```
  F(W, bias | X) = Einsum(W, X) + Q(bias)
  ```
  where `W` is a tensor of trainable variables, `bias` is a tensor of rank
  `len(bias_axes)`, `X` is a batch of inputs, and `Q` is a linear broadcast
  operator that roughly corresponds to `Q(bias) ~= tf.broadcast_to(bias, S)` for
  `S := tf.shape(Einsum(W, X))`.

  It is straightforward to show that the per-example adjoint of `Q` is given by
  `Q'(Y) := tf.reduce_sum(Y, axes=R)` where `R` contains the broadcasting
  indices. This function returns `R` as an unordered list of `int`s.

  Assumptions:

    A1. `equation` is one of the following forms:
          C1. `ab,bc->ac`
          C2. `...ab,bc->...ac`
          C3. `ab...,bc->ac...`

    A2. The first character in the substring `a` (or `...a` in C2)
        in assumption A1 corresponds to the batch dimension.

    A3. The characters in `bias_axes` must be subset of the non-batch dimension
        characters in the substring `ac` (or `...ac` in C2) in
        assumption A1.

    A4. `einsum_rank` is the length of the substring `ac` (or `...ac` in C2) in
        assumption A1. This includes the batch dimension.

  Examples:

    1. equation = 'ab,bc->ac', bias_axes = 'c', einsum_rank = 2 -> []
    2. equation = 'ab,bce->ace', bias_axes = 'ce', einsum_rank = 3, -> []
    3. equation = 'ab,bce->ace', bias_axes = 'c', einsum_rank = 3, -> [2]
    4. equation = 'ab,bce->ace', bias_axes = 'e', einsum_rank = 3, -> [1]
    5. equation = 'ab,bced->aced', bias_axes = 'ced', einsum_rank = 4 -> []
    6. equation = 'ab,bced->aced', bias_axes = 'ce', einsum_rank = 4, -> [3],
    7. equation = 'ab,bced->aced', bias_axes = 'c', einsum_rank = 4, -> [2, 3]
    8. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 4
       -> [1, 3]
    9. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 10
       -> [1, 2, 3, 4, 5, 6, 7, 9]
    10. equation = 'ab...,bce->ace...', bias_axes = 'e', einsum_rank = 4
       -> [1, 3]

  Args:
    equation: The einsum equation `string`.
    bias_axes: A substring of the output part of `equation` specifying which
      axes a bias `tf.Tensor` is added to.
    einsum_rank: The rank of the tensor that the per-example adjoint operator is
      being applied to.

  Returns:
    A list of `int` axes of the rank-`einsum_rank` tensor that the adjoint sums
    over. Every entry lies in `[1, einsum_rank - 1]`, i.e., the batch axis
    (index zero) is never included. The order of the entries is unspecified.

  Raises:
    ValueError: If `equation` is not a valid einsum equation in the context of
      the `tf.keras.layers.EinsumDense` layer, if `bias_axes` contains the
      batch axis, or if `einsum_rank` is smaller than the number of named
      output axes in `equation`.
  """
  equation_type, (_, _, ac_str) = _parse_einsum_equation(equation)
  is_left_ellipses = equation_type == EquationType.LEFT_ELLIPSES
  # Do not allow the bias axes to be the batch axis, since we want the adjoint
  # of the bias broadcast op to apply the same operation to all examples in a
  # batch.
  if not is_left_ellipses and ac_str[0] in bias_axes:
    raise ValueError(f"Bias axis '{bias_axes}' cannot also be the batch axis.")
  # Every named output axis needs a matching tensor index. Without this guard
  # `zip_longest` would pad the index side with `None`, and those `None`s
  # would leak into the returned list.
  if einsum_rank < len(ac_str):
    raise ValueError(
        f"Einsum rank {einsum_rank} is smaller than the number of output axes"
        f" in '{ac_str}'."
    )
  # If `equation` of the form `...ab,bc->...ac`, i.e., case C2, we do a
  # right to left traversal; the other cases do a left to right traversal.
  # Either way, indices that fall inside the `...` part are paired with `None`.
  input_indices = range(einsum_rank)
  if is_left_ellipses:
    traversal_zip = itertools.zip_longest(
        reversed(input_indices), reversed(ac_str)
    )
  else:
    traversal_zip = itertools.zip_longest(input_indices, ac_str)
  # Traverse the output part of `equation` and add an index to the output if
  # the corresponding `char` in the `ac` part is NOT in `bias_axes` and the
  # index is not zero (batch dimension). Add all indices except index zero in
  # the `...` part of the output substring (if present).
  remaining_bias_chars = set(bias_axes)
  reduction_axes = []
  for idx, output_char in traversal_zip:
    # Exclude the batch dimension (idx == 0), since we want the per-example
    # adjoint.
    if idx == 0:
      continue
    if output_char in remaining_bias_chars:
      # Each bias character accounts for exactly one output axis.
      remaining_bias_chars.remove(output_char)
    else:
      # Either an ellipsis axis (`output_char is None`) or a named axis that
      # the bias is broadcast over.
      reduction_axes.append(idx)
  return reduction_axes
|
||||
|
|
|
@ -45,6 +45,21 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
computed_result = einsum_utils._is_batch_of_vectors(t)
|
||||
self.assertEqual(computed_result, true_result)
|
||||
|
||||
@parameterized.product(
    experiment_params=[
        (('ab', 'bc', 'ac'), True),
        (('ab', 'a', 'b'), False),
        (('ab', 'ca', 'bc'), False),
        (('b', 'bc', 'c'), True),
        (('ab', 'bc', 'bc'), False),
        (('abc', 'cde', 'abde'), True),
    ]
)
def test_is_valid_einsum_equation(self, experiment_params):
  """Checks the `ab,bc->ac` validator against hand-labelled triples."""
  # Each case is ((maybe_ab, maybe_bc, maybe_ac), expected_validity).
  substrings, expected = experiment_params
  self.assertEqual(
      einsum_utils._is_valid_einsum_equation(*substrings), expected
  )
|
||||
|
||||
@parameterized.product(
|
||||
experiment_params=[
|
||||
(
|
||||
|
@ -66,7 +81,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
def test_parse_einsum_equation(self, experiment_params):
|
||||
equation, true_eqn_type, true_groups = experiment_params
|
||||
(computed_eqn_type, computed_groups) = einsum_utils._parse_einsum_equation(
|
||||
computed_eqn_type, computed_groups = einsum_utils._parse_einsum_equation(
|
||||
equation
|
||||
)
|
||||
self.assertEqual(computed_eqn_type, true_eqn_type)
|
||||
|
@ -98,7 +113,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
]
|
||||
)
|
||||
def test_reshape_einsum_inputs(self, experiment_params):
|
||||
(equation, input_shape, true_permutations, true_parsed_shape) = (
|
||||
equation, input_shape, true_permutations, true_parsed_shape = (
|
||||
experiment_params
|
||||
)
|
||||
num_entries = int(np.prod(input_shape))
|
||||
|
@ -141,7 +156,7 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
]
|
||||
)
|
||||
def test_reshape_einsum_outputs(self, experiment_params):
|
||||
(equation, output_shape, true_permutations, true_parsed_shape) = (
|
||||
equation, output_shape, true_permutations, true_parsed_shape = (
|
||||
experiment_params
|
||||
)
|
||||
num_entries = int(np.prod(output_shape))
|
||||
|
@ -158,6 +173,77 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape)
|
||||
self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor)
|
||||
|
||||
@parameterized.product(
    experiment_params=[
        # einsum_utils.EquationType.NO_ELLIPSES
        ('ab,bc->ac', 'c', 2, []),
        ('ab,bce->ace', 'ce', 3, []),
        ('ab,bce->ace', 'ec', 3, []),
        ('ab,bce->ace', 'c', 3, [2]),
        ('ab,bce->ace', 'e', 3, [1]),
        ('ab,bced->aced', 'ced', 4, []),
        ('ab,bced->aced', 'edc', 4, []),
        ('ab,bced->aced', 'ce', 4, [3]),
        ('ab,bced->aced', 'ec', 4, [3]),
        ('ab,bced->aced', 'cd', 4, [2]),
        ('ab,bced->aced', 'ed', 4, [1]),
        ('ab,bced->aced', 'c', 4, [2, 3]),
        ('ab,bced->aced', 'e', 4, [1, 3]),
        ('ab,bced->aced', 'd', 4, [1, 2]),
        # einsum_utils.EquationType.LEFT_ELLIPSES
        ('...b,bc->...c', 'c', 2, []),
        ('...b,bce->...ce', 'c', 3, [2]),
        ('...b,bce->...ce', 'e', 3, [1]),
        ('...ab,bc->...ac', 'c', 3, [1]),
        ('...ab,bce->...ace', 'ac', 4, [3]),
        ('...ab,bce->...ace', 'ae', 4, [2]),
        ('...ab,bce->...ace', 'ce', 4, [1]),
        ('...ab,bce->...ace', 'ec', 4, [1]),
        ('...ab,bce->...ace', 'a', 4, [2, 3]),
        ('...ab,bce->...ace', 'c', 4, [1, 3]),
        ('...ab,bce->...ace', 'e', 4, [1, 2]),
        ('...ab,bce->...ace', 'c', 5, [1, 2, 4]),
        ('...ab,bce->...ace', 'c', 10, [1, 2, 3, 4, 5, 6, 7, 9]),
        # einsum_utils.EquationType.RIGHT_ELLIPSES
        ('ab...,bc->ac...', 'c', 3, [2]),
        ('ab...,bce->ace...', 'ce', 4, [3]),
        ('ab...,bce->ace...', 'ec', 4, [3]),
        ('ab...,bce->ace...', 'c', 4, [2, 3]),
        ('ab...,bce->ace...', 'e', 4, [1, 3]),
    ]
)
def test_get_einsum_bias_adjoint_reduction_axes(self, experiment_params):
  """Checks the computed reduction axes for each supported equation form."""
  # Each case is (equation, bias_axes, einsum_rank, expected reduction axes).
  equation, bias_axes, einsum_rank, true_reduction_axes = experiment_params
  computed_reduction_axes = (
      einsum_utils._get_einsum_bias_adjoint_reduction_axes(
          equation, bias_axes, einsum_rank
      )
  )
  # The function returns an unordered list, so compare order-insensitively.
  # `sorted` makes copies; the shared parameter lists are left untouched.
  self.assertAllEqual(
      sorted(computed_reduction_axes), sorted(true_reduction_axes)
  )
|
||||
|
||||
@parameterized.product(
    experiment_params=[
        # einsum_utils.EquationType.NO_ELLIPSES
        ('ab,bc->ac', 'a', 2),
        # einsum_utils.EquationType.RIGHT_ELLIPSES
        ('ab...,bc->ac...', 'a', 3),
        ('ab...,bc->ac...', 'a', 4),
        ('ab...,bcde->acde...', 'acd', 4),
    ]
)
def test_bias_axis_eq_batch_axis_throws_error(self, experiment_params):
  """Checks that a bias on the batch axis is rejected with a clear message."""
  # Each case is (equation, bias_axes, einsum_rank), where `bias_axes`
  # contains the leading (batch) character of the output substring.
  equation, bias_axes, einsum_rank = experiment_params
  expected_message = f"Bias axis '{bias_axes}' cannot also be the batch axis."
  with self.assertRaises(ValueError) as raised:
    einsum_utils._get_einsum_bias_adjoint_reduction_axes(
        equation, bias_axes, einsum_rank
    )
  self.assertEqual(expected_message, str(raised.exception))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue