Add the first set of EinsumDense utility functions for implementing fast gradient norm computation.

PiperOrigin-RevId: 564460945
This commit is contained in:
A. Unique TensorFlower 2023-09-11 12:05:41 -07:00
parent a23cccde8b
commit 113b27be43
3 changed files with 371 additions and 0 deletions

View file

@ -2,6 +2,20 @@ package(
default_visibility = ["//visibility:public"],
)
# Library target exposing the EinsumDense helper functions in einsum_utils.py.
py_library(
    name = "einsum_utils",
    srcs = ["einsum_utils.py"],
    srcs_version = "PY3",
)
# Unit tests for the helpers in einsum_utils.py; depends only on the
# :einsum_utils library target declared in this package.
py_test(
    name = "einsum_utils_test",
    srcs = ["einsum_utils_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":einsum_utils"],
)
py_library(
name = "dense",
srcs = ["dense.py"],

View file

@ -0,0 +1,194 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various helper functions related to `tf.keras.layers.EinsumDense`."""
import enum
import itertools
import re
import numpy as np
import tensorflow as tf
class EquationType(enum.Enum):
  """Classifies an einsum equation by the placement of its ellipses.

  The member values mirror the 1-based numbering that the functional
  `enum.Enum(name, [names...])` constructor would assign.
  """

  # Not produced by the parser in this module; presumably a placeholder for
  # "not yet classified" -- confirm against callers.
  UNKNOWN = 1
  # Equations of the form `ab,bc->ac`.
  NO_ELLIPSES = 2
  # Equations of the form `...ab,bc->...ac`.
  LEFT_ELLIPSES = 3
  # Equations of the form `ab...,bc->ac...`.
  RIGHT_ELLIPSES = 4
def _is_batch_of_vectors(t: tf.Tensor) -> bool:
  """Checks if an input is a batch of (effectively) 1D vectors.

  The leading axis is skipped (it is treated as the batch axis). The input
  qualifies when at most one of the remaining axes has a size greater than 1.

  Args:
    t: The tensor whose `shape` is inspected.

  Returns:
    `True` if no more than one non-leading axis has size greater than 1,
    `False` otherwise.
  """
  found_nontrivial_axis = False
  for axis_size in t.shape[1:]:
    if axis_size > 1:
      if found_nontrivial_axis:
        # A second non-trivial axis means each example is not a vector.
        return False
      found_nontrivial_axis = True
  return True
def _parse_einsum_equation(
    equation: str,
) -> tuple[EquationType, tuple[str, str, str]]:
  """Returns the EquationType and I/O substrings of an einsum equation.

  Args:
    equation: The einsum equation `string`.

  Returns:
    A nested tuple `(equation_type, (ab_str, bc_str, ac_str))`, where
    `equation_type` specifies the type of einsum equation and `**_str`
    are the components of the equation with any ellipses removed. Excluding
    ellipses, the input equation should be of the form `ab,bc->ac` where `a`,
    `b`, and `c` can themselves be substrings.

  Raises:
    ValueError: If `equation` is not a valid einsum equation in the context of
      the `tf.keras.layers.EinsumDense` layer.
  """

  def _try_match(regex_str):
    # The whole equation must match; partial matches are rejected.
    maybe_match = re.fullmatch(regex_str, equation)
    return maybe_match.groups() if maybe_match is not None else None

  # The three supported forms are mutually exclusive, so the order in which
  # they are tried does not affect the result.
  groups1 = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)")
  if groups1 is not None:
    return EquationType.NO_ELLIPSES, groups1
  groups2 = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)")
  if groups2 is not None:
    return EquationType.LEFT_ELLIPSES, groups2
  groups3 = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.")
  if groups3 is not None:
    return EquationType.RIGHT_ELLIPSES, groups3
  # Only the first literal is an f-string; the braces in the following plain
  # literals are emitted verbatim.
  raise ValueError(
      f"Invalid Einsum equation string '{equation}'. "
      "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
      "{ab...,bc->ac...}"
  )
def _reshape_einsum_inputs(
    input_tensor: tf.Tensor,
    equation: str,
) -> tf.Tensor:
  """Converts an input tensor of arbitrary rank to a batched matrix tensor.

  Args:
    input_tensor: A `tf.Tensor` corresponding to the first input of the einsum
      equation. Its first dimension is treated as the batch dimension.
    equation: The einsum equation `string`.

  Returns:
    A rank-3 `tf.Tensor` representing a batch of rank-2 matrices. The output
    dimensions, in order, are:
    ```
    (num_batches, num_rows, num_columns)
    ```
    When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1
    and the number of output columns is the second dimension of the input. The
    product of the non-trivial dimensions of the output should be equal to
    the product of the dimensions of `input_tensor`. If the batch dimension is
    not statically known, it is inferred by `tf.reshape` at run time.

  Raises:
    ValueError: If `equation` is rejected by `_parse_einsum_equation()`.
  """

  def _shared_prefix_len(left, right):
    """Counts the leading positions at which `left` and `right` agree."""
    agreeing_pairs = itertools.takewhile(
        lambda pair: pair[0] == pair[1], zip(left, right)
    )
    return sum(1 for _ in agreeing_pairs)

  # Find the components `ab`, `bc`, and `ac` given that `equation` can only be
  # one of the following mutually exclusive forms:
  #
  #   (C1) ab,bc->ac,
  #   (C2) ...ab,bc->...ac
  #   (C3) ab...,bc->ac...
  #
  # NOTE: `a`, `b`, and `c` are (possibly) also substrings.

  # Compute the first index of the `b` part of the `ab` component.
  input_len = len(input_tensor.shape)
  equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
  if equation_type == EquationType.LEFT_ELLIPSES:
    # In case (C2), the `a` part of this component can be empty, so we have no
    # choice but to compare the `c` part of `ac` with the `bc` component. The
    # length of `c` is the length of the common suffix of `bc` and `ac`.
    c_len = _shared_prefix_len(reversed(bc_str), reversed(ac_str))
    b_len = len(bc_str) - c_len
    b_idx = input_len - b_len
  else:
    # For the other cases, we simply compare `ab` with `ac` to get the length
    # of the `a` component, i.e., the first index of `b`.
    b_idx = _shared_prefix_len(ab_str, ac_str)

  # Prepare `input_tensor` for reshaping and get the pivot index of the prepped
  # tensor. Note that case (C3) requires a transpose to ensure that matrix
  # multiplication is performed by the caller.
  if equation_type == EquationType.RIGHT_ELLIPSES:
    ellipses_idx = len(ab_str)
    # Convert `ab...` to `a...b`.
    new_ordering = (
        list(range(0, b_idx))
        + list(range(ellipses_idx, input_len))
        + list(range(b_idx, ellipses_idx))
    )
    input_tensor = tf.transpose(input_tensor, perm=new_ordering)
    ellipses_len = input_len - ellipses_idx
    pivot_idx = b_idx + ellipses_len
  else:
    pivot_idx = b_idx

  # The output tensor is a batched set of matrices, split at the pivot index
  # of the previously prepped tensor.
  base_tensor_shape = input_tensor.shape
  static_batch_size = base_tensor_shape[0]
  # `tf.reshape` cannot take `None` as a dimension; `-1` asks it to infer the
  # batch size from the remaining (static) dimensions instead.
  batch_size = -1 if static_batch_size is None else static_batch_size
  num_rows = int(np.prod(base_tensor_shape[1:pivot_idx]))
  num_columns = int(np.prod(base_tensor_shape[pivot_idx:]))
  return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])
def _reshape_einsum_outputs(
    output_tensor: tf.Tensor,
    equation: str,
) -> tf.Tensor:
  """Converts an output tensor of arbitrary rank to a batched matrix tensor.

  The work is delegated to `_reshape_einsum_inputs()` by handing it a
  "reversed" equation in which the output component takes the place of the
  first input component (and vice versa) and the kernel component is written
  backwards.

  Args:
    output_tensor: A `tf.Tensor` corresponding to the output of the einsum
      equation.
    equation: The einsum equation `string`.

  Returns:
    A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
    product of the non-trivial dimensions of the output should be equal to
    the product of the non-trivial dimensions of `output_tensor`.

  Raises:
    ValueError: If `equation` does not have the overall shape
      `<input>,<kernel>-><output>`. Equations that pass this coarse check but
      are still malformed are rejected by `_reshape_einsum_inputs()`.
  """
  # Coarse split only: ellipses are kept inside each component so that they
  # carry over into the reversed equation.
  match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation)
  if match is None:
    # Only the first literal is an f-string; the braces in the following plain
    # literals are emitted verbatim.
    raise ValueError(
        f"Invalid Einsum equation string '{equation}'. "
        "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
        "{ab...,bc->ac...}"
    )
  s1, s2, s3 = match.groups()
  # E.g. `ab,bc->ac` becomes `ac,cb->ab`.
  reversed_equation = s3 + "," + s2[::-1] + "->" + s1
  return _reshape_einsum_inputs(output_tensor, reversed_equation)

View file

@ -0,0 +1,163 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for the private helpers in `einsum_utils`."""

  def _make_counting_tensor(self, shape):
    """Returns a tensor of `shape` filled with 0, 1, 2, ... in row-major order."""
    total = int(np.prod(shape))
    return tf.reshape(tf.range(0, total), shape)

  def _expected_matrix_batch(self, tensor, permutation, matrix_shape):
    """Builds the reference result: optional transpose, then a reshape."""
    if permutation is not None:
      tensor = tf.transpose(tensor, perm=permutation)
    return tf.reshape(tensor, matrix_shape)

  def _check_reshape(self, reshape_fn, experiment_params):
    """Compares `reshape_fn` against the hand-specified transpose/reshape."""
    equation, tensor_shape, permutation, matrix_shape = experiment_params
    tensor = self._make_counting_tensor(tensor_shape)
    actual = reshape_fn(tensor, equation)
    expected = self._expected_matrix_batch(tensor, permutation, matrix_shape)
    self.assertAllEqual(actual, expected)

  # Each entry is `(tensor shape, expected result)`.
  @parameterized.product(
      experiment_params=[
          # 1D tensors
          ([1], True),
          ([2], True),
          # 2D tensors
          ([1, 2], True),
          ([2, 1], True),
          ([2, 2], True),
          # 3D tensors
          ([2, 1, 1], True),
          ([1, 2, 1], True),
          ([1, 1, 2], True),
          ([2, 2, 1], True),
          ([2, 1, 2], True),
          ([1, 2, 2], False),
          ([2, 2, 2], False),
      ]
  )
  def test_is_batch_of_vectors(self, experiment_params):
    shape, expected = experiment_params
    actual = einsum_utils._is_batch_of_vectors(tf.zeros(shape))
    self.assertEqual(actual, expected)

  # Each entry is `(equation, expected type, expected components)`.
  @parameterized.product(
      experiment_params=[
          (
              'ab,bc->ac',
              einsum_utils.EquationType.NO_ELLIPSES,
              ('ab', 'bc', 'ac'),
          ),
          (
              '...b,bc->...c',
              einsum_utils.EquationType.LEFT_ELLIPSES,
              ('b', 'bc', 'c'),
          ),
          (
              'ab...,bc->ac...',
              einsum_utils.EquationType.RIGHT_ELLIPSES,
              ('ab', 'bc', 'ac'),
          ),
      ]
  )
  def test_parse_einsum_equation(self, experiment_params):
    equation, expected_type, expected_groups = experiment_params
    actual_type, actual_groups = einsum_utils._parse_einsum_equation(equation)
    self.assertEqual(actual_type, expected_type)
    self.assertEqual(actual_groups, expected_groups)

  # Each entry is `(equation, input shape, transpose perm or None, expected
  # rank-3 shape)`.
  @parameterized.product(
      experiment_params=[
          # einsum_utils.EquationType.NO_ELLIPSES
          ('ab,bc->ac', [2, 3], None, [2, 1, 3]),
          ('adb,bc->adc', [2, 3, 4], None, [2, 3, 4]),
          ('adeb,bc->adec', [2, 3, 4, 5], None, [2, 12, 5]),
          ('abe,bec->ac', [2, 3, 4], None, [2, 1, 12]),
          ('ab,bce->ace', [2, 3], None, [2, 1, 3]),
          # einsum_utils.EquationType.LEFT_ELLIPSES
          ('...b,bc->...c', [2, 3], None, [2, 1, 3]),
          ('...b,bc->...c', [2, 3, 4], None, [2, 3, 4]),
          ('...b,bc->...c', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...ab,bc->...ac', [2, 3, 4], None, [2, 3, 4]),
          ('...ab,bc->...ac', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...be,bec->...c', [2, 3, 4], None, [2, 1, 12]),
          ('...b,bce->...ce', [2, 3], None, [2, 1, 3]),
          # einsum_utils.EquationType.RIGHT_ELLIPSES
          ('ab...,bc->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
          ('ab...,bc->ac...', [2, 3, 4, 5], [0, 2, 3, 1], [2, 20, 3]),
          ('adb...,bc->adc...', [2, 3, 4, 5], [0, 1, 3, 2], [2, 15, 4]),
          ('adeb...,bc->adec...', [2, 3, 4, 5, 6], [0, 1, 2, 4, 3], [2, 72, 5]),
          ('abe...,bec->ac...', [2, 3, 4, 5], [0, 3, 1, 2], [2, 5, 12]),
          ('ab...,bce->ace...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
      ]
  )
  def test_reshape_einsum_inputs(self, experiment_params):
    self._check_reshape(einsum_utils._reshape_einsum_inputs, experiment_params)

  # Each entry is `(equation, output shape, transpose perm or None, expected
  # rank-3 shape)`.
  @parameterized.product(
      experiment_params=[
          # einsum_utils.EquationType.NO_ELLIPSES
          ('ab,bc->ac', [2, 3], None, [2, 1, 3]),
          ('adb,bc->adc', [2, 3, 4], None, [2, 3, 4]),
          ('adeb,bc->adec', [2, 3, 4, 5], None, [2, 12, 5]),
          ('abe,bec->ac', [2, 3, 4], None, [2, 1, 12]),
          ('ab,bce->ace', [2, 3, 4], None, [2, 1, 12]),
          # einsum_utils.EquationType.LEFT_ELLIPSES
          ('...b,bc->...c', [2, 3], None, [2, 1, 3]),
          ('...b,bc->...c', [2, 3, 4], None, [2, 3, 4]),
          ('...b,bc->...c', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...ab,bc->...ac', [2, 3, 4], None, [2, 3, 4]),
          ('...ab,bc->...ac', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...be,bec->...c', [2, 4], None, [2, 1, 4]),
          ('...b,bce->...ce', [2, 3, 4], None, [2, 1, 12]),
          # einsum_utils.EquationType.RIGHT_ELLIPSES
          ('ab...,bc->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
          ('ab...,bc->ac...', [2, 3, 4, 5], [0, 2, 3, 1], [2, 20, 3]),
          ('adb...,bc->adc...', [2, 3, 4, 5], [0, 1, 3, 2], [2, 15, 4]),
          ('adeb...,bc->adec...', [2, 3, 4, 5, 6], [0, 1, 2, 4, 3], [2, 72, 5]),
          ('abe...,bec->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
          ('ab...,bce->ace...', [2, 3, 4, 5], [0, 3, 1, 2], [2, 5, 12]),
      ]
  )
  def test_reshape_einsum_outputs(self, experiment_params):
    self._check_reshape(einsum_utils._reshape_einsum_outputs, experiment_params)
# Hand control to the TensorFlow test runner when this file is executed
# directly rather than imported.
if __name__ == '__main__':
  tf.test.main()