forked from 626_privacy/tensorflow_privacy
Add compatability for Einsum layers with dynamic shapes.
PiperOrigin-RevId: 628111219
This commit is contained in:
parent
3fa0a2d362
commit
3deaae30a1
2 changed files with 45 additions and 7 deletions
|
@ -28,11 +28,11 @@ EquationType = enum.Enum(
|
|||
)
|
||||
|
||||
|
||||
def _is_batch_of_vectors(t: tf.Tensor) -> bool:
|
||||
def is_batch_of_vectors(t: tf.Tensor) -> bool:
|
||||
"""Checks if an input is a batch of (effectively) 1D vectors."""
|
||||
num_nontrivial_indices = 0
|
||||
for s in t.shape[1:]:
|
||||
if s > 1:
|
||||
if s is None or s > 1:
|
||||
num_nontrivial_indices += 1
|
||||
if num_nontrivial_indices > 1:
|
||||
return False
|
||||
|
@ -442,8 +442,8 @@ def compute_fast_einsum_squared_gradient_norm(
|
|||
# NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do
|
||||
# a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`.
|
||||
if (
|
||||
_is_batch_of_vectors(input_tensor)
|
||||
and _is_batch_of_vectors(grad_tensor)
|
||||
is_batch_of_vectors(input_tensor)
|
||||
and is_batch_of_vectors(grad_tensor)
|
||||
and num_microbatches is None
|
||||
):
|
||||
x_matrix = tf.reshape(x, [tf.shape(x)[0], -1])
|
||||
|
|
|
@ -29,20 +29,58 @@ class EinsumUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
([1, 2], True),
|
||||
([2, 1], True),
|
||||
([2, 2], True),
|
||||
# 3D tensors
|
||||
# 3D tensors (batch of vectors)
|
||||
([2, 1, 1], True),
|
||||
([1, 2, 1], True),
|
||||
([1, 1, 2], True),
|
||||
([2, 2, 1], True),
|
||||
([2, 1, 2], True),
|
||||
# 3D tensors (batch of matrices)
|
||||
([1, 2, 2], False),
|
||||
([2, 2, 2], False),
|
||||
]
|
||||
)
|
||||
def test_is_batch_of_vectors(self, experiment_params):
|
||||
def test_is_batch_of_vectors_on_static_shapes(self, experiment_params):
|
||||
shape, true_result = experiment_params
|
||||
t = tf.zeros(shape)
|
||||
computed_result = einsum_utils._is_batch_of_vectors(t)
|
||||
computed_result = einsum_utils.is_batch_of_vectors(t)
|
||||
self.assertEqual(computed_result, true_result)
|
||||
|
||||
@parameterized.product(
|
||||
experiment_params=[
|
||||
# 1D tensors
|
||||
([None], True),
|
||||
# 2D tensors
|
||||
([1, None], True),
|
||||
([None, 1], True),
|
||||
([None, None], True),
|
||||
([2, None], True),
|
||||
([None, 2], True),
|
||||
# 3D tensors (batch of vectors)
|
||||
([None, 1, 1], True),
|
||||
([1, None, 1], True),
|
||||
([1, 1, None], True),
|
||||
([None, None, 1], True),
|
||||
([None, 2, 1], True),
|
||||
([2, None, 1], True),
|
||||
([None, 1, None], True),
|
||||
([None, 1, 2], True),
|
||||
([2, 1, None], True),
|
||||
([1, None, None], False),
|
||||
# 3D tensors (batch of matrices)
|
||||
([1, 2, None], False),
|
||||
([1, None, 2], False),
|
||||
([None, None, None], False),
|
||||
([None, 2, None], False),
|
||||
([None, None, 2], False),
|
||||
([2, None, 2], False),
|
||||
([2, 2, None], False),
|
||||
]
|
||||
)
|
||||
def test_is_batch_of_vectors_on_dynamic_shapes(self, experiment_params):
|
||||
shape, true_result = experiment_params
|
||||
t = tf.keras.Input(shape=shape[1:], batch_size=shape[0])
|
||||
computed_result = einsum_utils.is_batch_of_vectors(t)
|
||||
self.assertEqual(computed_result, true_result)
|
||||
|
||||
@parameterized.product(
|
||||
|
|
Loading…
Reference in a new issue