diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
index fd9fc31..e6a948f 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
@@ -2,6 +2,20 @@ package(
     default_visibility = ["//visibility:public"],
 )
 
+py_library(
+    name = "einsum_utils",
+    srcs = ["einsum_utils.py"],
+    srcs_version = "PY3",
+)
+
+py_test(
+    name = "einsum_utils_test",
+    srcs = ["einsum_utils_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [":einsum_utils"],
+)
+
 py_library(
     name = "dense",
     srcs = ["dense.py"],
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
new file mode 100644
index 0000000..4a947ba
--- /dev/null
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
@@ -0,0 +1,194 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Various helper functions related to `tf.keras.layers.EinsumDense`."""
+
+import enum
+import itertools
+import re
+
+import numpy as np
+import tensorflow as tf
+
+EquationType = enum.Enum(
+    "EquationType",
+    ["UNKNOWN", "NO_ELLIPSES", "LEFT_ELLIPSES", "RIGHT_ELLIPSES"],
+)
+
+
+def _is_batch_of_vectors(t: tf.Tensor) -> bool:
+  """Checks if an input is a batch of (effectively) 1D vectors."""
+  num_nontrivial_indices = 0
+  for s in t.shape[1:]:
+    if s > 1:
+      num_nontrivial_indices += 1
+    if num_nontrivial_indices > 1:
+      return False
+  return num_nontrivial_indices <= 1
+
+
+def _parse_einsum_equation(
+    equation: str,
+) -> tuple[EquationType, tuple[str, str, str]]:
+  """Returns the EquationType and I/O substrings of an einsum equation.
+
+  Args:
+    equation: The einsum equation `string`.
+
+  Returns:
+    A nested tuple `(equation_type, (ab_str, bc_str, ac_str))`, where
+    `equation_type` specifies the type of einsum equation and `**_str`
+    are the components of the equation. Excluding ellipses, the input equation
+    should be of the form `ab,bc->ac` where `a`, `b`, and `c` can themselves
+    be substrings.
+
+  Raises:
+    ValueError: If `equation` is not a valid einsum equation in the context of
+      the `tf.keras.layers.EinsumDense` layer.
+  """
+
+  def _try_match(regex_str):
+    maybe_match = re.fullmatch(regex_str, equation)
+    return maybe_match.groups() if maybe_match is not None else None
+
+  groups1 = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)")
+  if groups1 is not None:
+    return EquationType.NO_ELLIPSES, groups1
+  groups2 = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)")
+  if groups2 is not None:
+    return EquationType.LEFT_ELLIPSES, groups2
+  groups3 = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.")
+  if groups3 is not None:
+    return EquationType.RIGHT_ELLIPSES, groups3
+  raise ValueError(
+      "Invalid Einsum equation string "
+      + equation
+      + ". "
+      "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
+      "{ab...,bc->ac...}"
+  )
+
+
+def _reshape_einsum_inputs(
+    input_tensor: tf.Tensor,
+    equation: str,
+) -> tf.Tensor:
+  """Converts an input tensor of arbitrary rank to a batched matrix tensor.
+
+  Args:
+    input_tensor: A `tf.Tensor` corresponding to the first input of the einsum
+      equation.
+    equation: The einsum equation `string`.
+
+  Returns:
+    A rank-3 `tf.Tensor` representing a batch of rank-2 matrices, all with the
+    same number of rows and columns. The output dimensions, in order, are:
+    ```
+    (num_batches, num_rows, num_columns)
+    ```
+    When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1
+    and the number of output columns is the second dimension of the input. The
+    product of the non-trivial dimensions of the output should be equal to
+    the product of the dimensions of `input_tensor`.
+  """
+  # Find the components `ab`, `bc`, and `ac` given that `equation` can only be
+  # one of the following mutually exclusive forms:
+  #
+  #   (C1) ab,bc->ac,
+  #   (C2) ...ab,bc->...ac
+  #   (C3) ab...,bc->ac...
+  #
+  # NOTE: `a`, `b`, and `c` are (possibly) also substrings.
+
+  # Compute the first index of the `b` part of the `ab` component.
+  input_shape = input_tensor.shape
+  input_len = len(input_shape)
+  equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
+  if equation_type == EquationType.LEFT_ELLIPSES:
+    # In case (C2), the `a` part of this component can be empty, so we have no
+    # choice but to compare the `c` part of `ac` with the `bc` component.
+    c_len = 0
+    for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)):
+      if s1 == s2:
+        c_len += 1
+      else:
+        break
+    b_len = len(bc_str) - c_len
+    b_idx = input_len - b_len
+  else:
+    # For the other cases, we simply compare `ab` with `ac` to get the length
+    # of the `a` component, i.e., the first index of `b`.
+    b_idx = 0
+    for s1, s2 in itertools.zip_longest(ab_str, ac_str):
+      if s1 == s2:
+        b_idx += 1
+      else:
+        break
+  # Prepare `input_tensor` for reshaping and get the pivot index of the prepped
+  # tensor. Note that case (C3) requires a transpose so that the caller can
+  # perform a standard batched matrix multiplication on the result.
+  if equation_type == EquationType.RIGHT_ELLIPSES:
+    ellipses_idx = len(ab_str)
+    # Convert `ab...` to `a...b`.
+    new_ordering = (
+        list(range(0, b_idx))
+        + list(range(ellipses_idx, input_len))
+        + list(range(b_idx, ellipses_idx))
+    )
+    input_tensor = tf.transpose(input_tensor, perm=new_ordering)
+    ellipses_len = input_len - ellipses_idx
+    pivot_idx = b_idx + ellipses_len
+  else:
+    pivot_idx = b_idx
+  # The output tensor is a batched set of matrices, split at the pivot index
+  # of the previously prepped tensor.
+  base_tensor_shape = input_tensor.shape
+  batch_size = base_tensor_shape[0]
+  num_rows = int(np.prod(base_tensor_shape[1:pivot_idx]))
+  num_columns = int(np.prod(base_tensor_shape[pivot_idx:]))
+  return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])
+
+
+def _reshape_einsum_outputs(
+    output_tensor: tf.Tensor,
+    equation: str,
+) -> tf.Tensor:
+  """Converts an output tensor of arbitrary rank to a batched matrix tensor.
+
+  The logic is almost the same as in `_reshape_einsum_inputs()` except
+  in the case where the equation is left-elided by ellipses. For this case,
+  we need to pass in a reversed kernel shape.
+
+  Args:
+    output_tensor: A `tf.Tensor` corresponding to the output of the einsum
+      equation.
+    equation: The einsum equation `string`.
+
+  Returns:
+    A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
+    product of the non-trivial dimensions of the output should be equal to
+    the product of the non-trivial dimensions of `output_tensor`.
+  """
+  match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation)
+  if match is not None:
+    s1, s2, s3 = match.groups()
+  else:
+    raise ValueError(
+        "Invalid Einsum equation string "
+        + equation
+        + ". "
+        "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
+        "{ab...,bc->ac...}"
+    )
+  reversed_equation = s3 + "," + s2[::-1] + "->" + s1
+  return _reshape_einsum_inputs(output_tensor, reversed_equation)
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils_test.py
new file mode 100644
index 0000000..1fd740c
--- /dev/null
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils_test.py
@@ -0,0 +1,163 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from absl.testing import parameterized +import numpy as np +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils + + +class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + experiment_params=[ + # 1D tensors + ([1], True), + ([2], True), + # 2D tensors + ([1, 2], True), + ([2, 1], True), + ([2, 2], True), + # 3D tensors + ([2, 1, 1], True), + ([1, 2, 1], True), + ([1, 1, 2], True), + ([2, 2, 1], True), + ([2, 1, 2], True), + ([1, 2, 2], False), + ([2, 2, 2], False), + ] + ) + def test_is_batch_of_vectors(self, experiment_params): + shape, true_result = experiment_params + t = tf.zeros(shape) + computed_result = einsum_utils._is_batch_of_vectors(t) + self.assertEqual(computed_result, true_result) + + @parameterized.product( + experiment_params=[ + ( + 'ab,bc->ac', + einsum_utils.EquationType.NO_ELLIPSES, + ('ab', 'bc', 'ac'), + ), + ( + '...b,bc->...c', + einsum_utils.EquationType.LEFT_ELLIPSES, + ('b', 'bc', 'c'), + ), + ( + 'ab...,bc->ac...', + einsum_utils.EquationType.RIGHT_ELLIPSES, + ('ab', 'bc', 'ac'), + ), + ] + ) + def test_parse_einsum_equation(self, experiment_params): + equation, true_eqn_type, true_groups = experiment_params + (computed_eqn_type, computed_groups) = einsum_utils._parse_einsum_equation( + equation + ) + self.assertEqual(computed_eqn_type, true_eqn_type) + self.assertEqual(computed_groups, true_groups) + + @parameterized.product( + experiment_params=[ + # einsum_utils.EquationType.NO_ELLIPSES + ('ab,bc->ac', [2, 3], None, [2, 1, 3]), + ('adb,bc->adc', [2, 3, 4], None, [2, 3, 4]), + ('adeb,bc->adec', [2, 3, 4, 5], None, [2, 12, 5]), + ('abe,bec->ac', [2, 3, 4], None, [2, 1, 12]), + ('ab,bce->ace', [2, 3], None, [2, 1, 3]), + # einsum_utils.EquationType.LEFT_ELLIPSES + ('...b,bc->...c', [2, 3], None, [2, 1, 3]), + ('...b,bc->...c', [2, 3, 4], None, [2, 3, 4]), + ('...b,bc->...c', [2, 3, 4, 5], None, [2, 12, 5]), + ('...ab,bc->...ac', [2, 3, 4], None, [2, 3, 4]), + ('...ab,bc->...ac', [2, 3, 4, 5], None, [2, 12, 5]), + ('...be,bec->...c', [2, 3, 4], None, [2, 1, 12]), + ('...b,bce->...ce', [2, 3], None, [2, 1, 3]), + # einsum_utils.EquationType.RIGHT_ELLIPSES + ('ab...,bc->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]), + ('ab...,bc->ac...', [2, 3, 4, 5], [0, 2, 3, 1], [2, 20, 3]), + ('adb...,bc->adc...', [2, 3, 4, 5], [0, 1, 3, 2], [2, 15, 4]), + ('adeb...,bc->adec...', [2, 3, 4, 5, 6], [0, 1, 2, 4, 3], [2, 72, 5]), + ('abe...,bec->ac...', [2, 3, 4, 5], [0, 3, 1, 2], [2, 5, 12]), + ('ab...,bce->ace...', [2, 3, 4], [0, 2, 1], [2, 4, 3]), + ] + ) + def test_reshape_einsum_inputs(self, experiment_params): + (equation, input_shape, true_permutations, true_parsed_shape) = ( + experiment_params + ) + num_entries = int(np.prod(input_shape)) + input_tensor = tf.reshape(tf.range(0, num_entries), input_shape) + computed_parsed_tensor = einsum_utils._reshape_einsum_inputs( + input_tensor, + equation, + ) + true_parsed_tensor = input_tensor + if true_permutations is not None: + true_parsed_tensor = tf.transpose( + true_parsed_tensor, perm=true_permutations + ) + true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape) + self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor) + + @parameterized.product( + experiment_params=[ + # einsum_utils.EquationType.NO_ELLIPSES + ('ab,bc->ac', [2, 3], None, [2, 1, 3]), + ('adb,bc->adc', [2, 3, 4], None, [2, 3, 4]), + ('adeb,bc->adec', [2, 3, 4, 5], None, [2, 12, 5]), + ('abe,bec->ac', [2, 3, 4], None, [2, 1, 12]), + 
('ab,bce->ace', [2, 3, 4], None, [2, 1, 12]), + # einsum_utils.EquationType.LEFT_ELLIPSES + ('...b,bc->...c', [2, 3], None, [2, 1, 3]), + ('...b,bc->...c', [2, 3, 4], None, [2, 3, 4]), + ('...b,bc->...c', [2, 3, 4, 5], None, [2, 12, 5]), + ('...ab,bc->...ac', [2, 3, 4], None, [2, 3, 4]), + ('...ab,bc->...ac', [2, 3, 4, 5], None, [2, 12, 5]), + ('...be,bec->...c', [2, 4], None, [2, 1, 4]), + ('...b,bce->...ce', [2, 3, 4], None, [2, 1, 12]), + # einsum_utils.EquationType.RIGHT_ELLIPSES + ('ab...,bc->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]), + ('ab...,bc->ac...', [2, 3, 4, 5], [0, 2, 3, 1], [2, 20, 3]), + ('adb...,bc->adc...', [2, 3, 4, 5], [0, 1, 3, 2], [2, 15, 4]), + ('adeb...,bc->adec...', [2, 3, 4, 5, 6], [0, 1, 2, 4, 3], [2, 72, 5]), + ('abe...,bec->ac...', [2, 3, 4], [0, 2, 1], [2, 4, 3]), + ('ab...,bce->ace...', [2, 3, 4, 5], [0, 3, 1, 2], [2, 5, 12]), + ] + ) + def test_reshape_einsum_outputs(self, experiment_params): + (equation, output_shape, true_permutations, true_parsed_shape) = ( + experiment_params + ) + num_entries = int(np.prod(output_shape)) + output_tensor = tf.reshape(tf.range(0, num_entries), output_shape) + computed_parsed_tensor = einsum_utils._reshape_einsum_outputs( + output_tensor, + equation, + ) + true_parsed_tensor = output_tensor + if true_permutations is not None: + true_parsed_tensor = tf.transpose( + true_parsed_tensor, perm=true_permutations + ) + true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape) + self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor) + + +if __name__ == '__main__': + tf.test.main()
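
Usage sketch (illustrative only, not part of the patch; the equation, shapes, and tolerance below are arbitrary): the two reshape helpers expose an `EinsumDense`-style einsum as a single batched matrix multiplication over `(num_batches, num_rows, num_columns)` views.

    import tensorflow as tf

    from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils

    equation = "...b,bc->...c"           # a LEFT_ELLIPSES equation
    x = tf.random.normal([2, 3, 4, 5])   # layer input: (batch, ..., b)
    w = tf.random.normal([5, 6])         # kernel: (b, c)
    y = tf.einsum(equation, x, w)        # layer output: shape [2, 3, 4, 6]

    # Reshape both sides into batched-matrix views.
    x_mat = einsum_utils._reshape_einsum_inputs(x, equation)   # shape [2, 12, 5]
    y_mat = einsum_utils._reshape_einsum_outputs(y, equation)  # shape [2, 12, 6]

    # The batched matrix product of the reshaped input with the kernel recovers
    # the reshaped einsum output.
    tf.debugging.assert_near(tf.matmul(x_mat, w), y_mat, atol=1e-5)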