Add support for fast clipping of dense layer gradients where the dimension of the input is larger than 1.

This change specifically wraps the fast clipping logic used in EinsumDense layers, which is a generalization of the Gramian-based approach that was previously used for dense-layer clipping.

PiperOrigin-RevId: 585809850
William Kong 2023-11-27 17:58:08 -08:00 committed by A. Unique TensorFlower
parent b19088f048
commit f51b637dda
3 changed files with 12 additions and 28 deletions

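For reference, the Gramian-based approach mentioned above obtains per-example squared gradient norms without ever materializing per-example gradients. Below is a minimal sketch of the underlying identity for a bias-free dense layer applied to inputs with an extra non-batch dimension; the function name and shapes are illustrative and this is not the library implementation:

import tensorflow as tf


def per_example_sq_grad_norm(x, g):
  # x: layer inputs, shape [batch, tokens, d_in].
  # g: gradients of the loss w.r.t. the layer's pre-activation outputs,
  #    shape [batch, tokens, d_out].
  # The per-example kernel gradient is sum_t outer(x_t, g_t), so its squared
  # Frobenius norm is sum_{s,t} <x_s, x_t> * <g_s, g_t>: the elementwise
  # product of the two Gram matrices, summed over both token axes.
  x_gram = tf.matmul(x, x, transpose_b=True)  # [batch, tokens, tokens]
  g_gram = tf.matmul(g, g, transpose_b=True)  # [batch, tokens, tokens]
  return tf.reduce_sum(x_gram * g_gram, axis=[1, 2])  # [batch]

A bias term corresponds to appending a constant 1.0 feature to every input row, which adds 1.0 to each entry of the input Gram matrix (this is the inputs_gram += 1.0 step removed from dense.py below).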

@@ -57,6 +57,7 @@ py_library(
     srcs = ["dense.py"],
     srcs_version = "PY3",
     deps = [
+        ":einsum_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
     ],


@@ -16,8 +16,8 @@
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional
 import tensorflow as tf
-from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils
 
 
 def dense_layer_computation(
@@ -74,28 +74,12 @@ def dense_layer_computation(
   outputs = orig_activation(base_vars) if orig_activation else base_vars
 
   def sqr_norm_fn(base_vars_grads):
-    def _compute_gramian(x):
-      if num_microbatches is not None:
-        x_microbatched = common_manip_utils.maybe_add_microbatch_axis(
-            x,
-            num_microbatches,
-        )
-        return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
-      else:
-        # Special handling for better efficiency
-        return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
-
-    inputs_gram = _compute_gramian(*input_args)
-    base_vars_grads_gram = _compute_gramian(base_vars_grads)
-    if layer_instance.use_bias:
-      # Adding a bias term is equivalent to a layer with no bias term and which
-      # adds an additional variable to the layer input that only takes a
-      # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
-      # of the squared values of the inputs.
-      inputs_gram += 1.0
-    return tf.reduce_sum(
-        inputs_gram * base_vars_grads_gram,
-        axis=tf.range(1, tf.rank(inputs_gram)),
-    )
+    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
+        "...b,bc->...c",
+        input_args[0],
+        base_vars_grads,
+        "c" if layer_instance.use_bias else None,
+        num_microbatches,
+    )
 
   return base_vars, outputs, sqr_norm_fn

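The equation "...b,bc->...c" passed above is the einsum form of a Dense layer's kernel multiplication (bias and activation excluded), which is why the EinsumDense norm logic applies directly; the "c" argument presumably names the output axis that carries the bias when use_bias is set. An illustrative check of that equivalence (not part of the change; shapes chosen arbitrarily):

import tensorflow as tf

layer = tf.keras.layers.Dense(4, use_bias=False)
x = tf.random.normal([2, 3, 5])  # [batch, tokens, input_dim]
y_layer = layer(x)  # builds layer.kernel with shape [5, 4]
y_einsum = tf.einsum("...b,bc->...c", x, layer.kernel)
# y_layer and y_einsum agree up to floating-point error, so the dense layer's
# per-example squared gradient norms can be delegated to the einsum utilities.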

@@ -19,7 +19,6 @@ import os
 import re
 from typing import Optional
 
-import numpy as np
 import tensorflow as tf
 
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
@@ -198,10 +197,10 @@ def _reshape_einsum_inputs(
     pivot_idx = b_idx
   # The output tensor is a batched set of matrices, split at the pivot index
   # of the previously prepped tensor.
-  base_tensor_shape = input_tensor.shape
-  batch_size = base_tensor_shape[0]
-  num_rows = int(np.prod(base_tensor_shape[1:pivot_idx]))
-  num_columns = int(np.prod(base_tensor_shape[pivot_idx:]))
+  input_shape = tf.shape(input_tensor)
+  batch_size = input_shape[0]
+  num_rows = tf.reduce_prod(input_shape[1:pivot_idx])
+  num_columns = tf.reduce_prod(input_shape[pivot_idx:])
   return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])
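
The switch from np.prod over input_tensor.shape to tf.shape with tf.reduce_prod keeps the reshape usable when some dimensions are only known at runtime, e.g. an unspecified batch size inside tf.function. A small sketch of the failure mode this avoids (illustrative, not library code):

import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec([None, 3, 4], tf.float32)])
def flatten_trailing(t):
  # At trace time t.shape is (None, 3, 4); feeding that None into a reshape or
  # a static product would fail, while tf.shape(t) is resolved at runtime.
  shape = tf.shape(t)
  batch_size = shape[0]
  num_columns = tf.reduce_prod(shape[1:])
  return tf.reshape(t, [batch_size, 1, num_columns])


print(flatten_trailing(tf.zeros([2, 3, 4])).shape)  # (2, 1, 12)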