diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD index 465969e..9828852 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD @@ -19,6 +19,7 @@ py_test( name = "einsum_utils_test", srcs = ["einsum_utils_test.py"], python_version = "PY3", + shard_count = 4, srcs_version = "PY3", deps = [":einsum_utils"], ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py index 4a947ba..4912939 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py @@ -15,6 +15,7 @@ import enum import itertools +import os import re import numpy as np @@ -37,6 +38,36 @@ def _is_batch_of_vectors(t: tf.Tensor) -> bool: return num_nontrivial_indices <= 1 +def _is_valid_einsum_equation( + maybe_ab: str, + maybe_bc: str, + maybe_ac: str, +) -> bool: + """Checks if three input strings form a valid einsum dense equation. + + Given three substrings `maybe_ab`, `maybe_bc`, and `maybe_ac`, this function + checks if + ``` + maybe_ab + ',' + maybe_bc + '->' + maybe_ac + ``` + is an einsum equation of the form `ab,bc->ac`. + + Args: + maybe_ab: The proposed `ab` substring. + maybe_bc: The proposed `bc` substring. + maybe_ac: The proposed `ac` substring. + + Returns: + `True` if the three input strings form an einsum equation of the form + `ab,bc->ac` and `False` otherwise. 
+ """ + a_substr = os.path.commonprefix([maybe_ab, maybe_ac]) + a_len = len(a_substr) + b_substr = maybe_ab[a_len:] + c_substr = maybe_ac[a_len:] + return maybe_bc == b_substr + c_substr + + def _parse_einsum_equation( equation: str, ) -> tuple[EquationType, tuple[str, str, str]]: @@ -61,22 +92,33 @@ def _parse_einsum_equation( maybe_match = re.fullmatch(regex_str, equation) return maybe_match.groups() if maybe_match is not None else None - groups1 = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)") - if groups1 is not None: - return EquationType.NO_ELLIPSES, groups1 - groups2 = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)") - if groups2 is not None: - return EquationType.LEFT_ELLIPSES, groups2 - groups3 = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.") - if groups3 is not None: - return EquationType.RIGHT_ELLIPSES, groups3 - raise ValueError( + error_message = ( "Invalid Einsum equation string " + equation + " ." "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, " "{ab...,bc->ac...}" ) + case_pairs = [ + # equation_type, regex_str + (EquationType.NO_ELLIPSES, r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)"), + ( + EquationType.LEFT_ELLIPSES, + r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)", + ), + ( + EquationType.RIGHT_ELLIPSES, + r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.", + ), + ] + for equation_type, regex_str in case_pairs: + groups = _try_match(regex_str) + if groups is not None: + if not _is_valid_einsum_equation(*groups): + raise ValueError(error_message) + return equation_type, groups + # No valid cases found. Raise an error. + raise ValueError(error_message) def _reshape_einsum_inputs( @@ -100,13 +142,17 @@ def _reshape_einsum_inputs( and the number of output columns is the second dimension of the input. The product of the non-trivial dimensions of the output should be equal to the product of the dimensions of `input_tensor`. 
+ + Raises: + ValueError: If `equation` is not a valid einsum equation in the context of + the `tf.keras.layers.EinsumDense` layer. """ # Find the components `ab`, `bc`, and `ac` given that `equation` can only be # one of the following mutually exclusive forms: # - # (C1) ab,bc->ac, - # (C2) ...ab,bc->...ac - # (C3) ab...,bc->ac... + # C1. ab,bc->ac, + # C2. ...ab,bc->...ac + # C3. ab...,bc->ac... # # NOTE: `a`, `b`, and `c` are (possibly) also substrings. @@ -115,7 +161,7 @@ def _reshape_einsum_inputs( input_len = len(input_shape) equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation) if equation_type == EquationType.LEFT_ELLIPSES: - # In case (C2), the `a` part of this component can be empty, so we have no + # In case C2, the `a` part of this component can be empty, so we have no # choice but to compare the `c` part of `ac` with the `bc` component. c_len = 0 for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)): @@ -135,7 +181,7 @@ def _reshape_einsum_inputs( else: break # Prepare `input_tensor` for reshaping and get the pivot index of the prepped - # tensor. Note that case (C3) requires a transpose to ensure that matrix + # tensor. Note that case C3 requires a transpose to ensure that matrix # multiplication is performed by the caller. if equation_type == EquationType.RIGHT_ELLIPSES: ellipses_idx = len(ab_str) @@ -178,17 +224,124 @@ def _reshape_einsum_outputs( A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The product of the non-trivial dimensions of the output should be equal to the product of the non-trivial dimensions of `output_tensor`. + + Raises: + ValueError: If `equation` is not a valid einsum equation in the context of + the `tf.keras.layers.EinsumDense` layer. """ - match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation) - if match is not None: - s1, s2, s3 = match.groups() - else: - raise ValueError( - "Invalid Einsum equation string " - + equation - + " ." 
def _get_einsum_bias_adjoint_reduction_axes(
    equation: str,
    bias_axes: str,
    einsum_rank: int,
) -> list[int]:
    """Computes axes related to the per-example adjoint of the einsum bias op.

    To describe the output of this computation, first recall that for each
    example the `EinsumDense` layer performs the following transformation:
    ```
    F(W, bias | X) = Einsum(W, X) + Q(bias)
    ```
    where `W` is a tensor of trainable variables, `bias` is a tensor of rank
    `len(bias_axes)`, `X` is a batch of inputs, and `Q` is a linear broadcast
    operator that roughly corresponds to `Q(bias) ~= tf.broadcast_to(bias, S)`
    for `S := tf.shape(Einsum(W, X))`.

    It is straightforward to show that the per-example adjoint of `Q` is given
    by `Q'(Y) := tf.reduce_sum(Y, axis=R)` where `R` contains the broadcasting
    indices. This function returns `R` as an unordered list of `int`s.

    Assumptions:

      A1. `equation` is one of the following forms:
            C1. `ab,bc->ac`
            C2. `...ab,bc->...ac`
            C3. `ab...,bc->ac...`

      A2. The first character in the substring `a` (or `...a` in C2)
          in assumption A1 corresponds to the batch dimension.

      A3. The characters in `bias_axes` must be a subset of the non-batch
          dimension characters in the substring `ac` (or `...ac` in C2) in
          assumption A1. This is not verified here; characters that do not
          occur in the output are ignored.

      A4. `einsum_rank` is the length of the substring `ac` (or `...ac` in C2)
          in assumption A1. This includes the batch dimension.

    Examples:

      1. equation = 'ab,bc->ac', bias_axes = 'c', einsum_rank = 2 -> []
      2. equation = 'ab,bce->ace', bias_axes = 'ce', einsum_rank = 3 -> []
      3. equation = 'ab,bce->ace', bias_axes = 'c', einsum_rank = 3 -> [2]
      4. equation = 'ab,bce->ace', bias_axes = 'e', einsum_rank = 3 -> [1]
      5. equation = 'ab,bced->aced', bias_axes = 'ced', einsum_rank = 4 -> []
      6. equation = 'ab,bced->aced', bias_axes = 'ce', einsum_rank = 4 -> [3]
      7. equation = 'ab,bced->aced', bias_axes = 'c', einsum_rank = 4 -> [2, 3]
      8. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 4
           -> [1, 3]
      9. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 10
           -> [1, 2, 3, 4, 5, 6, 7, 9]
      10. equation = 'ab...,bce->ace...', bias_axes = 'e', einsum_rank = 4
           -> [1, 3]

    Args:
      equation: The einsum equation `string`.
      bias_axes: A substring of the output part of `equation` specifying which
        axes a bias `tf.Tensor` is added to.
      einsum_rank: The rank of the tensor that the per-example adjoint operator
        is being applied to.

    Returns:
      A list of `int` containing the axes of a rank-`einsum_rank` tensor that
      the bias is broadcast over. Each `int` is at least 1 (the batch axis is
      always excluded) and at most `einsum_rank - 1`.

    Raises:
      ValueError: If `equation` is not a valid einsum equation in the context
        of the `tf.keras.layers.EinsumDense` layer, if a bias axis coincides
        with the batch axis, or if `einsum_rank` is smaller than the number of
        explicit output indices in `equation`.
    """
    bias_char_set = set(bias_axes)
    equation_type, (_, _, ac_str) = _parse_einsum_equation(equation)
    is_left_ellipses = equation_type == EquationType.LEFT_ELLIPSES
    # Do not allow the bias axes to be the batch axis, since we want the
    # adjoint of the bias broadcast op to apply the same operation to all
    # examples in a batch. In case C2 the batch axis lives in the `...` part,
    # so the first explicit output character is not the batch axis.
    if not is_left_ellipses and ac_str[0] in bias_axes:
        raise ValueError(f"Bias axis '{bias_axes}' cannot also be the batch axis.")
    # Every explicit output index needs a matching tensor axis. Without this
    # guard `zip_longest` would pad the axis index with `None`, which would
    # then leak into the returned list.
    if einsum_rank < len(ac_str):
        raise ValueError(
            f"`einsum_rank` ({einsum_rank}) must be at least the number of "
            f"output indices in '{ac_str}'."
        )
    # If `equation` is of the form `...ab,bc->...ac`, i.e., case C2, the
    # explicit characters are aligned with the trailing axes, so we traverse
    # right to left; the other cases traverse left to right. Axes without a
    # matching character (the `...` part) are paired with `None`.
    input_indices = range(einsum_rank)
    traversal_zip = (
        itertools.zip_longest(reversed(input_indices), reversed(ac_str))
        if is_left_ellipses
        else itertools.zip_longest(input_indices, ac_str)
    )
    # Reduce over every non-batch axis (idx != 0) that does not carry a bias
    # character. This includes all axes in the `...` part, because `None` is
    # never a member of `bias_char_set`.
    return [
        idx
        for idx, output_char in traversal_zip
        if idx != 0 and output_char not in bias_char_set
    ]
@parameterized.product(
    experiment_params=[
        # einsum_utils.EquationType.NO_ELLIPSES
        ('ab,bc->ac', 'c', 2, []),
        ('ab,bce->ace', 'ce', 3, []),
        ('ab,bce->ace', 'ec', 3, []),
        ('ab,bce->ace', 'c', 3, [2]),
        ('ab,bce->ace', 'e', 3, [1]),
        ('ab,bced->aced', 'ced', 4, []),
        ('ab,bced->aced', 'edc', 4, []),
        ('ab,bced->aced', 'ce', 4, [3]),
        ('ab,bced->aced', 'ec', 4, [3]),
        ('ab,bced->aced', 'cd', 4, [2]),
        ('ab,bced->aced', 'ed', 4, [1]),
        ('ab,bced->aced', 'c', 4, [2, 3]),
        ('ab,bced->aced', 'e', 4, [1, 3]),
        ('ab,bced->aced', 'd', 4, [1, 2]),
        # einsum_utils.EquationType.LEFT_ELLIPSES
        ('...b,bc->...c', 'c', 2, []),
        ('...b,bce->...ce', 'c', 3, [2]),
        ('...b,bce->...ce', 'e', 3, [1]),
        ('...ab,bc->...ac', 'c', 3, [1]),
        ('...ab,bce->...ace', 'ac', 4, [3]),
        ('...ab,bce->...ace', 'ae', 4, [2]),
        ('...ab,bce->...ace', 'ce', 4, [1]),
        ('...ab,bce->...ace', 'ec', 4, [1]),
        ('...ab,bce->...ace', 'a', 4, [2, 3]),
        ('...ab,bce->...ace', 'c', 4, [1, 3]),
        ('...ab,bce->...ace', 'e', 4, [1, 2]),
        ('...ab,bce->...ace', 'c', 5, [1, 2, 4]),
        ('...ab,bce->...ace', 'c', 10, [1, 2, 3, 4, 5, 6, 7, 9]),
        # einsum_utils.EquationType.RIGHT_ELLIPSES
        ('ab...,bc->ac...', 'c', 3, [2]),
        ('ab...,bce->ace...', 'ce', 4, [3]),
        ('ab...,bce->ace...', 'ec', 4, [3]),
        ('ab...,bce->ace...', 'c', 4, [2, 3]),
        ('ab...,bce->ace...', 'e', 4, [1, 3]),
    ]
)
def test_get_einsum_bias_adjoint_reduction_axes(self, experiment_params):
    """Checks the reduction axes for each supported equation form.

    The function under test returns an unordered list, so both sides are
    compared after sorting.
    """
    equation, bias_axes, einsum_rank, true_reduction_axes = experiment_params
    computed_reduction_axes = (
        einsum_utils._get_einsum_bias_adjoint_reduction_axes(
            equation, bias_axes, einsum_rank
        )
    )
    # Compare sorted copies instead of sorting in place, so the expectation
    # lists held by the parameterization above are never mutated.
    self.assertAllEqual(
        sorted(computed_reduction_axes), sorted(true_reduction_axes)
    )
'{bias_axes}' cannot also be the batch axis.", + str(context.exception), + ) + if __name__ == '__main__': tf.test.main()