diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
index b72263e..ab16da6 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
@@ -57,6 +57,7 @@ py_library(
     srcs = ["dense.py"],
     srcs_version = "PY3",
     deps = [
+        ":einsum_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
     ],
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py
index 7c49c5b..4218a1a 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py
@@ -16,8 +16,8 @@ from collections.abc import Mapping, Sequence
 from typing import Any, Optional
 
 import tensorflow as tf
-from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
 
 
 def dense_layer_computation(
@@ -74,28 +74,12 @@
   outputs = orig_activation(base_vars) if orig_activation else base_vars
 
   def sqr_norm_fn(base_vars_grads):
-    def _compute_gramian(x):
-      if num_microbatches is not None:
-        x_microbatched = common_manip_utils.maybe_add_microbatch_axis(
-            x,
-            num_microbatches,
-        )
-        return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
-      else:
-        # Special handling for better efficiency
-        return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
-
-    inputs_gram = _compute_gramian(*input_args)
-    base_vars_grads_gram = _compute_gramian(base_vars_grads)
-    if layer_instance.use_bias:
-      # Adding a bias term is equivalent to a layer with no bias term and which
-      # adds an additional variable to the layer input that only takes a
-      # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
-      # of the squared values of the inputs.
-      inputs_gram += 1.0
-    return tf.reduce_sum(
-        inputs_gram * base_vars_grads_gram,
-        axis=tf.range(1, tf.rank(inputs_gram)),
+    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
+        "...b,bc->...c",
+        input_args[0],
+        base_vars_grads,
+        "c" if layer_instance.use_bias else None,
+        num_microbatches,
     )
 
   return base_vars, outputs, sqr_norm_fn
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
index b7480ac..84b79e3 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py
@@ -19,7 +19,6 @@ import os
 import re
 from typing import Optional
 
-import numpy as np
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 
@@ -198,10 +197,10 @@ def _reshape_einsum_inputs(
     pivot_idx = b_idx
   # The output tensor is a batched set of matrices, split at the pivot index
   # of the previously prepped tensor.
-  base_tensor_shape = input_tensor.shape
-  batch_size = base_tensor_shape[0]
-  num_rows = int(np.prod(base_tensor_shape[1:pivot_idx]))
-  num_columns = int(np.prod(base_tensor_shape[pivot_idx:]))
+  input_shape = tf.shape(input_tensor)
+  batch_size = input_shape[0]
+  num_rows = tf.reduce_prod(input_shape[1:pivot_idx])
+  num_columns = tf.reduce_prod(input_shape[pivot_idx:])
   return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])
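
A minimal sketch of the motivation for the last hunk (illustrative only, not part of the patch; the function name batched_matrix_reshape, the input signature, and the pivot value 2 are made-up examples): the static input_tensor.shape can contain unknown (None) dimensions when the reshape is traced inside a tf.function, which breaks np.prod, whereas tf.shape and tf.reduce_prod resolve the sizes at run time.

import tensorflow as tf

# Sketch of the dynamic-shape reshape adopted in _reshape_einsum_inputs.
# The first two dimensions are unknown at trace time, so a static-shape
# np.prod-based version would fail here; tf.shape/tf.reduce_prod do not.
@tf.function(input_signature=[tf.TensorSpec([None, None, 3, 4])])
def batched_matrix_reshape(x):
  pivot_idx = 2  # arbitrary example pivot
  shape = tf.shape(x)  # dynamic shape; concrete values at run time
  batch_size = shape[0]
  num_rows = tf.reduce_prod(shape[1:pivot_idx])
  num_columns = tf.reduce_prod(shape[pivot_idx:])
  return tf.reshape(x, [batch_size, num_rows, num_columns])

print(batched_matrix_reshape(tf.zeros([5, 7, 3, 4])).shape)  # (5, 7, 12)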
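
For the dense.py hunk, a small numerical check (again illustrative, with made-up shapes) of the identity that the removed Gramian code computes in the non-microbatched, bias-free case: for a dense layer y = x @ W, the per-example gradient with respect to W is the outer product of the input x_i and the gradient g_i with respect to the pre-activation output, so its squared Frobenius norm factors as ||x_i||^2 * ||g_i||^2. As the removed comment notes, use_bias adds 1.0 to the input's squared norm.

import tensorflow as tf

# Numerical check of the fast per-example squared-norm identity for a
# bias-free dense layer; x stands in for the layer input and g for
# base_vars_grads (the gradient w.r.t. the pre-activation output).
batch, d_in, d_out = 4, 3, 2
x = tf.random.normal([batch, d_in])
g = tf.random.normal([batch, d_out])

per_example_grads = tf.einsum("bi,bj->bij", x, g)  # outer products, one per example
direct = tf.reduce_sum(tf.square(per_example_grads), axis=[1, 2])
fast = tf.reduce_sum(tf.square(x), axis=1) * tf.reduce_sum(tf.square(g), axis=1)
tf.debugging.assert_near(direct, fast)  # the two computations agree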