From 3deaae30a15967e767dc755904611d48d8ed334b Mon Sep 17 00:00:00 2001 From: William Kong Date: Thu, 25 Apr 2024 10:08:19 -0700 Subject: [PATCH] Add compatibility for Einsum layers with dynamic shapes. PiperOrigin-RevId: 628111219 --- .../registry_functions/einsum_utils.py | 8 ++-- .../registry_functions/einsum_utils_test.py | 44 +++++++++++++++++-- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py index 84b79e3..b506961 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py @@ -28,11 +28,11 @@ EquationType = enum.Enum( ) -def _is_batch_of_vectors(t: tf.Tensor) -> bool: +def is_batch_of_vectors(t: tf.Tensor) -> bool: """Checks if an input is a batch of (effectively) 1D vectors.""" num_nontrivial_indices = 0 for s in t.shape[1:]: - if s > 1: + if s is None or s > 1: num_nontrivial_indices += 1 if num_nontrivial_indices > 1: return False @@ -442,8 +442,8 @@ def compute_fast_einsum_squared_gradient_norm( # NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do # a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`. 
if ( - _is_batch_of_vectors(input_tensor) - and _is_batch_of_vectors(grad_tensor) + is_batch_of_vectors(input_tensor) + and is_batch_of_vectors(grad_tensor) and num_microbatches is None ): x_matrix = tf.reshape(x, [tf.shape(x)[0], -1]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils_test.py index 2e48fc7..102b30c 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils_test.py @@ -29,20 +29,58 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase): ([1, 2], True), ([2, 1], True), ([2, 2], True), - # 3D tensors + # 3D tensors (batch of vectors) ([2, 1, 1], True), ([1, 2, 1], True), ([1, 1, 2], True), ([2, 2, 1], True), ([2, 1, 2], True), + # 3D tensors (batch of matrices) ([1, 2, 2], False), ([2, 2, 2], False), ] ) - def test_is_batch_of_vectors(self, experiment_params): + def test_is_batch_of_vectors_on_static_shapes(self, experiment_params): shape, true_result = experiment_params t = tf.zeros(shape) - computed_result = einsum_utils._is_batch_of_vectors(t) + computed_result = einsum_utils.is_batch_of_vectors(t) + self.assertEqual(computed_result, true_result) + + @parameterized.product( + experiment_params=[ + # 1D tensors + ([None], True), + # 2D tensors + ([1, None], True), + ([None, 1], True), + ([None, None], True), + ([2, None], True), + ([None, 2], True), + # 3D tensors (batch of vectors) + ([None, 1, 1], True), + ([1, None, 1], True), + ([1, 1, None], True), + ([None, None, 1], True), + ([None, 2, 1], True), + ([2, None, 1], True), + ([None, 1, None], True), + ([None, 1, 2], True), + ([2, 1, None], True), + ([1, None, None], False), + # 3D tensors (batch of matrices) + ([1, 2, None], False), + ([1, None, 2], False), + ([None, None, None], False), + ([None, 2, None], False), + 
([None, None, 2], False), + ([2, None, 2], False), + ([2, 2, None], False), + ] + ) + def test_is_batch_of_vectors_on_dynamic_shapes(self, experiment_params): + shape, true_result = experiment_params + t = tf.keras.Input(shape=shape[1:], batch_size=shape[0]) + computed_result = einsum_utils.is_batch_of_vectors(t) self.assertEqual(computed_result, true_result) @parameterized.product(