forked from 626_privacy/tensorflow_privacy
Add the first set of EinsumDense utility functions for implementing fast gradient norm computation.
PiperOrigin-RevId: 564460945
This commit is contained in:
parent
a23cccde8b
commit
113b27be43
3 changed files with 371 additions and 0 deletions
|
@ -2,6 +2,20 @@ package(
|
||||||
default_visibility = ["//visibility:public"],
|
default_visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library target for `einsum_utils.py`, the helper functions related to
# `tf.keras.layers.EinsumDense`.
py_library(
    name = "einsum_utils",
    srcs = ["einsum_utils.py"],
    srcs_version = "PY3",
)
|
||||||
|
|
||||||
|
# Unit tests for the helpers defined in `einsum_utils.py`.
py_test(
    name = "einsum_utils_test",
    srcs = ["einsum_utils_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":einsum_utils"],
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "dense",
|
name = "dense",
|
||||||
srcs = ["dense.py"],
|
srcs = ["dense.py"],
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
# Copyright 2023, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Various helper functions related to `tf.keras.layers.EinsumDense`."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import itertools
|
||||||
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
class EquationType(enum.Enum):
  """Classifies where, if anywhere, ellipses appear in an einsum equation."""

  UNKNOWN = 1
  # Equations of the form `ab,bc->ac`.
  NO_ELLIPSES = 2
  # Equations of the form `...ab,bc->...ac`.
  LEFT_ELLIPSES = 3
  # Equations of the form `ab...,bc->ac...`.
  RIGHT_ELLIPSES = 4
|
||||||
|
|
||||||
|
|
||||||
|
def _is_batch_of_vectors(t: tf.Tensor) -> bool:
  """Checks if an input is a batch of (effectively) 1D vectors.

  Args:
    t: The tensor to inspect. Only `t.shape` is read, and its first entry is
      skipped (it is treated as the batch dimension).

  Returns:
    `True` if at most one of the non-batch dimensions has a size greater than
    one, and `False` otherwise. A dimension whose static size is unknown
    (`None`) is counted as non-trivial, since it cannot be shown to be one.
  """
  num_nontrivial_indices = 0
  for dim in t.shape[1:]:
    # `None > 1` raises a `TypeError`, so unknown sizes are tested explicitly.
    # NOTE(review): counting them as non-trivial is the conservative choice;
    # confirm that callers treat a `False` result as the general case.
    if dim is None or dim > 1:
      num_nontrivial_indices += 1
      if num_nontrivial_indices > 1:
        return False
  # Reaching this point means at most one non-trivial dimension was seen.
  return True
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_einsum_equation(
    equation: str,
) -> tuple[EquationType, tuple[str, str, str]]:
  """Returns the EquationType and I/O substrings of an einsum equation.

  Args:
    equation: The einsum equation `string`.

  Returns:
    A nested tuple `(equation_type, (ab_str, bc_str, ac_str))`, where
    `equation_type` specifies the type of einsum equation and `**_str`
    are the components of the equation with any ellipses removed. Excluding
    ellipses, the input equation should be of the form `ab,bc->ac` where `a`,
    `b`, and `c` can themselves be substrings.

  Raises:
    ValueError: If `equation` is not a valid einsum equation in the context of
      the `tf.keras.layers.EinsumDense` layer.
  """

  def _try_match(regex_str):
    # The whole equation must match; partial matches are rejected.
    maybe_match = re.fullmatch(regex_str, equation)
    return None if maybe_match is None else maybe_match.groups()

  # The three patterns cannot match the same string: they differ in whether
  # the first character, or the last one, is a literal dot.
  groups = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)")
  if groups is not None:
    return EquationType.NO_ELLIPSES, groups
  groups = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)")
  if groups is not None:
    return EquationType.LEFT_ELLIPSES, groups
  groups = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.")
  if groups is not None:
    return EquationType.RIGHT_ELLIPSES, groups
  # Only the first literal is an f-string; the braces in the remaining
  # literals are meant to appear verbatim in the message.
  raise ValueError(
      f"Invalid Einsum equation string {equation!r}. "
      "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
      "{ab...,bc->ac...}"
  )
|
||||||
|
|
||||||
|
|
||||||
|
def _reshape_einsum_inputs(
    input_tensor: tf.Tensor,
    equation: str,
) -> tf.Tensor:
  """Converts an input tensor of arbitrary rank to a batched matrix tensor.

  Args:
    input_tensor: A `tf.Tensor` corresponding to the first input of the einsum
      equation.
    equation: The einsum equation `string`.

  Returns:
    A rank-3 `tf.Tensor` representing a batch of rank-2 matrices that all have
    the same number of rows and columns. The output dimensions, in order, are:
    ```
    (num_batches, num_rows, num_columns)
    ```
    When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1
    and the number of output columns is the second dimension of the input. The
    product of the non-trivial dimensions of the output should be equal to
    the product of the dimensions of `input_tensor`.

  Raises:
    ValueError: If `equation` is rejected by `_parse_einsum_equation()`.
  """
  # Find the components `ab`, `bc`, and `ac` given that `equation` can only be
  # one of the following mutually exclusive forms:
  #
  #   (C1) ab,bc->ac,
  #   (C2) ...ab,bc->...ac
  #   (C3) ab...,bc->ac...
  #
  # NOTE: `a`, `b`, and `c` are (possibly) also substrings.

  def _num_leading_matches(xs, ys) -> int:
    """Counts the positions, from the start, at which `xs` and `ys` agree."""
    count = 0
    for x, y in zip(xs, ys):
      if x != y:
        break
      count += 1
    return count

  # Compute the first index of the `b` part of the `ab` component.
  input_len = len(input_tensor.shape)
  equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
  if equation_type == EquationType.LEFT_ELLIPSES:
    # In case (C2), the `a` part of this component can be empty, so we have no
    # choice but to compare the `c` part of `ac` with the `bc` component.
    c_len = _num_leading_matches(reversed(bc_str), reversed(ac_str))
    b_len = len(bc_str) - c_len
    b_idx = input_len - b_len
  else:
    # For the other cases, we simply compare `ab` with `ac` to get the length
    # of the `a` component, i.e., the first index of `b`.
    b_idx = _num_leading_matches(ab_str, ac_str)

  # Prepare `input_tensor` for reshaping and get the pivot index of the prepped
  # tensor. Note that case (C3) requires a transpose to ensure that matrix
  # multiplication is performed by the caller.
  if equation_type == EquationType.RIGHT_ELLIPSES:
    ellipses_idx = len(ab_str)
    # Convert `ab...` to `a...b`.
    new_ordering = (
        list(range(0, b_idx))
        + list(range(ellipses_idx, input_len))
        + list(range(b_idx, ellipses_idx))
    )
    input_tensor = tf.transpose(input_tensor, perm=new_ordering)
    ellipses_len = input_len - ellipses_idx
    pivot_idx = b_idx + ellipses_len
  else:
    pivot_idx = b_idx

  # The output tensor is a batched set of matrices, split at the pivot index
  # of the previously prepped tensor.
  base_tensor_shape = input_tensor.shape
  batch_size = base_tensor_shape[0]
  if batch_size is None:
    # `tf.reshape` does not accept `None` in `shape`; `-1` asks it to infer the
    # batch dimension from the number of elements instead.
    batch_size = -1
  num_rows = int(np.prod(base_tensor_shape[1:pivot_idx]))
  num_columns = int(np.prod(base_tensor_shape[pivot_idx:]))
  return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])
|
||||||
|
|
||||||
|
|
||||||
|
def _swap_einsum_kernel_parts(kernel_str: str, output_str: str) -> str:
  """Rewrites a kernel component `bc` as `cb`, keeping `b` and `c` intact."""
  # The `c` part is the longest suffix that the kernel shares with the output
  # component once the output's ellipses (if any) are stripped.
  output_core = output_str.strip(".")
  c_len = 0
  for kernel_char, output_char in zip(
      reversed(kernel_str), reversed(output_core)
  ):
    if kernel_char != output_char:
      break
    c_len += 1
  split_idx = len(kernel_str) - c_len
  return kernel_str[split_idx:] + kernel_str[:split_idx]


def _reshape_einsum_outputs(
    output_tensor: tf.Tensor,
    equation: str,
) -> tf.Tensor:
  """Converts an output tensor of arbitrary rank to a batched matrix tensor.

  The logic is the same as in `_reshape_einsum_inputs()`, applied to a
  "reversed" equation in which the output component plays the role of the
  input: `ab,bc->ac` becomes `ac,cb->ab`. The `b` and `c` parts of the kernel
  component are swapped as whole substrings rather than reversed character by
  character, so that a multi-character `b` part (e.g., `be` in
  `...be,bec->...c`) is still recognized when the equation is left-elided by
  ellipses.

  Args:
    output_tensor: A `tf.Tensor` corresponding to the output of the einsum
      equation.
    equation: The einsum equation `string`.

  Returns:
    A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
    product of the non-trivial dimensions of the output should be equal to
    the product of the non-trivial dimensions of `output_tensor`.

  Raises:
    ValueError: If `equation` does not have the form `<in>,<kernel>-><out>`,
      or if the reversed equation is rejected by `_reshape_einsum_inputs()`.
  """
  match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation)
  if match is None:
    # Only the first literal is an f-string; the braces in the remaining
    # literals are meant to appear verbatim in the message.
    raise ValueError(
        f"Invalid Einsum equation string {equation!r}. "
        "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
        "{ab...,bc->ac...}"
    )
  input_str, kernel_str, output_str = match.groups()
  swapped_kernel_str = _swap_einsum_kernel_parts(kernel_str, output_str)
  reversed_equation = output_str + "," + swapped_kernel_str + "->" + input_str
  return _reshape_einsum_inputs(output_tensor, reversed_equation)
|
|
@ -0,0 +1,163 @@
|
||||||
|
# Copyright 2023, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
|
||||||
|
|
||||||
|
|
||||||
|
class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for the private helper functions in `einsum_utils`."""

  def _check_reshape(self, reshape_fn, experiment_params):
    """Compares `reshape_fn` against an explicit transpose-then-reshape.

    Args:
      reshape_fn: The function under test, called as
        `reshape_fn(tensor, equation)`.
      experiment_params: A tuple `(equation, shape, permutation, new_shape)`,
        where `permutation` may be `None` when no transpose is expected.
    """
    equation, shape, permutation, new_shape = experiment_params
    tensor = tf.reshape(tf.range(0, int(np.prod(shape))), shape)
    actual = reshape_fn(tensor, equation)
    expected = tensor
    if permutation is not None:
      expected = tf.transpose(expected, perm=permutation)
    expected = tf.reshape(expected, new_shape)
    self.assertAllEqual(actual, expected)

  @parameterized.product(
      experiment_params=[
          # 1D tensors
          ([1], True),
          ([2], True),
          # 2D tensors
          ([1, 2], True),
          ([2, 1], True),
          ([2, 2], True),
          # 3D tensors
          ([2, 1, 1], True),
          ([1, 2, 1], True),
          ([1, 1, 2], True),
          ([2, 2, 1], True),
          ([2, 1, 2], True),
          ([1, 2, 2], False),
          ([2, 2, 2], False),
      ]
  )
  def test_is_batch_of_vectors(self, experiment_params):
    tensor_shape, expected = experiment_params
    actual = einsum_utils._is_batch_of_vectors(tf.zeros(tensor_shape))
    self.assertEqual(actual, expected)

  @parameterized.product(
      experiment_params=[
          (
              'ab,bc->ac',
              einsum_utils.EquationType.NO_ELLIPSES,
              ('ab', 'bc', 'ac'),
          ),
          (
              '...b,bc->...c',
              einsum_utils.EquationType.LEFT_ELLIPSES,
              ('b', 'bc', 'c'),
          ),
          (
              'ab...,bc->ac...',
              einsum_utils.EquationType.RIGHT_ELLIPSES,
              ('ab', 'bc', 'ac'),
          ),
      ]
  )
  def test_parse_einsum_equation(self, experiment_params):
    equation, expected_type, expected_groups = experiment_params
    parsed = einsum_utils._parse_einsum_equation(equation)
    actual_type, actual_groups = parsed
    self.assertEqual(actual_type, expected_type)
    self.assertEqual(actual_groups, expected_groups)

  @parameterized.product(
      experiment_params=[
          # einsum_utils.EquationType.NO_ELLIPSES
          ('ab,bc->ac', [2, 3], None, [2, 1, 3]),
          ('adb,bc->adc', [2, 3, 4], None, [2, 3, 4]),
          ('adeb,bc->adec', [2, 3, 4, 5], None, [2, 12, 5]),
          ('abe,bec->ac', [2, 3, 4], None, [2, 1, 12]),
          ('ab,bce->ace', [2, 3], None, [2, 1, 3]),
          # einsum_utils.EquationType.LEFT_ELLIPSES
          ('...b,bc->...c', [2, 3], None, [2, 1, 3]),
          ('...b,bc->...c', [2, 3, 4], None, [2, 3, 4]),
          ('...b,bc->...c', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...ab,bc->...ac', [2, 3, 4], None, [2, 3, 4]),
          ('...ab,bc->...ac', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...be,bec->...c', [2, 3, 4], None, [2, 1, 12]),
          ('...b,bce->...ce', [2, 3], None, [2, 1, 3]),
          # einsum_utils.EquationType.RIGHT_ELLIPSES
          ('ab...,bc->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
          ('ab...,bc->ac...', [2, 3, 4, 5], [0, 2, 3, 1], [2, 20, 3]),
          ('adb...,bc->adc...', [2, 3, 4, 5], [0, 1, 3, 2], [2, 15, 4]),
          ('adeb...,bc->adec...', [2, 3, 4, 5, 6], [0, 1, 2, 4, 3], [2, 72, 5]),
          ('abe...,bec->ac...', [2, 3, 4, 5], [0, 3, 1, 2], [2, 5, 12]),
          ('ab...,bce->ace...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
      ]
  )
  def test_reshape_einsum_inputs(self, experiment_params):
    self._check_reshape(
        lambda tensor, equation: einsum_utils._reshape_einsum_inputs(
            tensor, equation
        ),
        experiment_params,
    )

  @parameterized.product(
      experiment_params=[
          # einsum_utils.EquationType.NO_ELLIPSES
          ('ab,bc->ac', [2, 3], None, [2, 1, 3]),
          ('adb,bc->adc', [2, 3, 4], None, [2, 3, 4]),
          ('adeb,bc->adec', [2, 3, 4, 5], None, [2, 12, 5]),
          ('abe,bec->ac', [2, 3, 4], None, [2, 1, 12]),
          ('ab,bce->ace', [2, 3, 4], None, [2, 1, 12]),
          # einsum_utils.EquationType.LEFT_ELLIPSES
          ('...b,bc->...c', [2, 3], None, [2, 1, 3]),
          ('...b,bc->...c', [2, 3, 4], None, [2, 3, 4]),
          ('...b,bc->...c', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...ab,bc->...ac', [2, 3, 4], None, [2, 3, 4]),
          ('...ab,bc->...ac', [2, 3, 4, 5], None, [2, 12, 5]),
          ('...be,bec->...c', [2, 4], None, [2, 1, 4]),
          ('...b,bce->...ce', [2, 3, 4], None, [2, 1, 12]),
          # einsum_utils.EquationType.RIGHT_ELLIPSES
          ('ab...,bc->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
          ('ab...,bc->ac...', [2, 3, 4, 5], [0, 2, 3, 1], [2, 20, 3]),
          ('adb...,bc->adc...', [2, 3, 4, 5], [0, 1, 3, 2], [2, 15, 4]),
          ('adeb...,bc->adec...', [2, 3, 4, 5, 6], [0, 1, 2, 4, 3], [2, 72, 5]),
          ('abe...,bec->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]),
          ('ab...,bce->ace...', [2, 3, 4, 5], [0, 3, 1, 2], [2, 5, 12]),
      ]
  )
  def test_reshape_einsum_outputs(self, experiment_params):
    self._check_reshape(
        lambda tensor, equation: einsum_utils._reshape_einsum_outputs(
            tensor, equation
        ),
        experiment_params,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()
|
Loading…
Reference in a new issue