diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
index b72263e..ab16da6 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
@@ -57,6 +57,7 @@ py_library(
     srcs = ["dense.py"],
     srcs_version = "PY3",
     deps = [
+        ":einsum_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
     ],
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py
index 7c49c5b..4218a1a 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py
@@ -16,8 +16,8 @@
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional
 import tensorflow as tf
-from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
 
 
 def dense_layer_computation(
@@ -74,28 +74,12 @@ def dense_layer_computation(
   outputs = orig_activation(base_vars) if orig_activation else base_vars
 
   def sqr_norm_fn(base_vars_grads):
-    def _compute_gramian(x):
-      if num_microbatches is not None:
-        x_microbatched = common_manip_utils.maybe_add_microbatch_axis(
-            x,
-            num_microbatches,
-        )
-        return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
-      else:
-        # Special handling for better efficiency
-        return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
-
-    inputs_gram = _compute_gramian(*input_args)
-    base_vars_grads_gram = _compute_gramian(base_vars_grads)
-    if layer_instance.use_bias:
-      # Adding a bias term is equivalent to a layer with no bias term and which
-      # adds an additional variable to the layer input that only takes a
-      # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
-      # of the squared values of the inputs.
-      inputs_gram += 1.0
-    return tf.reduce_sum(
-        inputs_gram * base_vars_grads_gram,
-        axis=tf.range(1, tf.rank(inputs_gram)),
+    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
+        "...b,bc->...c",
+        input_args[0],
+        base_vars_grads,
+        "c" if layer_instance.use_bias else None,
+        num_microbatches,
     )
 
   return base_vars, outputs, sqr_norm_fn
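
Note (reviewer sketch, not part of the patch): the removed block computed the per-example squared gradient norm of the Dense kernel from the input and output-gradient Gramians, adding 1.0 to the input term when a bias is present. The new call delegates the same computation to the shared einsum helper with equation `"...b,bc->...c"`; the helper's name, argument order, and the `"c"` bias-axis label are taken only from this call site. Below is a minimal check of the underlying identity for 2-D inputs, assuming standard Keras/TF APIs only:

```python
import tensorflow as tf

# For a Dense layer on inputs x of shape [batch, b] with upstream gradients
# g = dL/d(pre-activation) of shape [batch, c], the per-example squared
# kernel(+bias) gradient norm is (||x||^2 + 1{use_bias}) * ||g||^2.
batch, b, c = 4, 3, 2
layer = tf.keras.layers.Dense(c, use_bias=True)
x = tf.random.normal([batch, b])

with tf.GradientTape() as tape:
  z = layer(x)                                   # pre-activation outputs
  per_example_losses = tf.reduce_sum(z * z, axis=1)

# Fast path: only needs the batched upstream gradient dL/dz.
g = tape.gradient(tf.reduce_sum(per_example_losses), z)
inputs_sq = tf.reduce_sum(x * x, axis=1) + 1.0   # +1.0 accounts for the bias
grads_sq = tf.reduce_sum(g * g, axis=1)
fast_norms = tf.sqrt(inputs_sq * grads_sq)

# Brute force: per-example gradient norms, one example at a time.
brute_norms = []
for i in range(batch):
  with tf.GradientTape() as t:
    zi = layer(x[i : i + 1])
    li = tf.reduce_sum(zi * zi)
  grads = t.gradient(li, layer.trainable_variables)
  brute_norms.append(tf.sqrt(sum(tf.reduce_sum(v * v) for v in grads)))

print(fast_norms.numpy())             # should match the brute-force norms
print(tf.stack(brute_norms).numpy())
```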
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
index b7480ac..84b79e3 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
@@ -19,7 +19,6 @@ import os
 import re
 from typing import Optional
 
-import numpy as np
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 
@@ -198,10 +197,10 @@ def _reshape_einsum_inputs(
     pivot_idx = b_idx
   # The output tensor is a batched set of matrices, split at the pivot index
   # of the previously prepped tensor.
-  base_tensor_shape = input_tensor.shape
-  batch_size = base_tensor_shape[0]
-  num_rows = int(np.prod(base_tensor_shape[1:pivot_idx]))
-  num_columns = int(np.prod(base_tensor_shape[pivot_idx:]))
+  input_shape = tf.shape(input_tensor)
+  batch_size = input_shape[0]
+  num_rows = tf.reduce_prod(input_shape[1:pivot_idx])
+  num_columns = tf.reduce_prod(input_shape[pivot_idx:])
   return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])
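
Note (reviewer sketch, not part of the patch): the switch from `input_tensor.shape` / `np.prod` to `tf.shape` / `tf.reduce_prod` lets the reshape work when some dimensions are statically unknown (`None`), e.g. under a `tf.function` traced with an unspecified batch or sequence length, where the old `int(np.prod(...))` would fail. A small illustration of the dynamic-shape behavior, using a simplified stand-in for the tail of `_reshape_einsum_inputs` (here `pivot_idx` is passed in directly):

```python
import tensorflow as tf

def reshape_at_pivot(input_tensor, pivot_idx):
  # Simplified stand-in for the tail of _reshape_einsum_inputs(): flatten the
  # non-batch dimensions into [rows, columns] split at pivot_idx, using
  # graph-time shapes so statically unknown (None) dimensions are handled.
  input_shape = tf.shape(input_tensor)
  batch_size = input_shape[0]
  num_rows = tf.reduce_prod(input_shape[1:pivot_idx])
  num_columns = tf.reduce_prod(input_shape[pivot_idx:])
  return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])

# Batch and sequence dimensions are statically unknown in this signature;
# np.prod over the static shape would see None, tf.reduce_prod(tf.shape(...))
# resolves the sizes at run time instead.
@tf.function(input_signature=[tf.TensorSpec([None, None, 2, 3], tf.float32)])
def fn(x):
  return reshape_at_pivot(x, pivot_idx=2)

print(fn(tf.zeros([4, 5, 2, 3])).shape)  # (4, 5, 6)
```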