Re-organize files and simplify test names.

These changes support a more modular layout as we add more layer registry
functions (and their corresponding tests). They also keep `clip_grads_test.py`
from accumulating an enormous number of lengthy tests.

PiperOrigin-RevId: 545779495
This commit is contained in:
A. Unique TensorFlower 2023-07-05 14:00:40 -07:00
parent 9536fb26e7
commit 6b8007ddde
15 changed files with 977 additions and 712 deletions
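
As a rough illustration (not part of this commit) of what the new layout enables,
a test for a future registry function can lean entirely on the shared helpers
instead of redefining them locally; the test class below is hypothetical.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry


class MyLayerGradNormTest(tf.test.TestCase):

  def test_gradient_norms(self):
    registry = layer_registry.make_default_layer_registry()
    x_batches, _ = common_test_utils.get_nd_test_batches(4)
    for x_batch in x_batches:
      computed, true = common_test_utils.get_computed_and_true_norms(
          model_generator=common_test_utils.make_two_layer_sequential_model,
          layer_generator=lambda a, b: tf.keras.layers.Dense(b),
          input_dims=4,
          output_dim=2,
          per_example_loss_fn=None,
          num_microbatches=None,
          is_eager=True,
          x_batch=x_batch,
          registry=registry,
      )
      self.assertAllClose(computed, true, rtol=1e-3, atol=1e-2)


if __name__ == '__main__':
  tf.test.main()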

View file

@@ -2,11 +2,38 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(default_visibility = ["//visibility:public"])
py_library(
name = "type_aliases",
srcs = ["type_aliases.py"],
srcs_version = "PY3",
)
py_library(
name = "common_manip_utils",
srcs = ["common_manip_utils.py"],
srcs_version = "PY3",
deps = [":type_aliases"],
)
py_library(
name = "common_test_utils",
srcs = ["common_test_utils.py"],
srcs_version = "PY3",
deps = [
":clip_grads",
":layer_registry",
":type_aliases",
],
)
py_library(
name = "gradient_clipping_utils",
srcs = ["gradient_clipping_utils.py"],
srcs_version = "PY3",
deps = [":layer_registry"],
deps = [
":layer_registry",
":type_aliases",
],
)
py_test(
@@ -14,13 +41,21 @@ py_test(
srcs = ["gradient_clipping_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":gradient_clipping_utils"],
deps = [
":gradient_clipping_utils",
":type_aliases",
],
)
py_library(
name = "layer_registry",
srcs = ["layer_registry.py"],
srcs_version = "PY3",
deps = [
":type_aliases",
"//tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions:dense",
"//tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions:embedding",
],
)
py_library(
@@ -28,20 +63,22 @@ py_library(
srcs = ["clip_grads.py"],
srcs_version = "PY3",
deps = [
":common_manip_utils",
":gradient_clipping_utils",
":layer_registry",
":type_aliases",
],
)
py_test(
name = "clip_grads_test",
size = "large",
srcs = ["clip_grads_test.py"],
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
":clip_grads",
":common_test_utils",
":layer_registry",
":type_aliases",
],
)

View file

@@ -21,20 +21,19 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
from typing import Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
from typing import List, Optional, Tuple
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
LossFn = Callable[..., tf.Tensor]
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def get_registry_generator_fn(
tape: tf.GradientTape,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
):
"""Creates the generator function for `compute_gradient_norms()`."""
if layer_registry is None:
@@ -70,11 +69,11 @@ def get_registry_generator_fn(
def compute_gradient_norms(
input_model: tf.keras.Model,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
x_batch: type_aliases.InputTensors,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[LossFn] = None,
num_microbatches: Optional[lr.BatchSize] = None,
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None,
):
"""Computes the per-example loss gradient norms for given data.
@@ -147,7 +146,10 @@ def compute_gradient_norms(
)
if num_microbatches is not None:
losses = tf.reduce_mean(
lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
common_manip_utils.maybe_add_microbatch_axis(
losses, num_microbatches
),
axis=1,
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
@@ -212,11 +214,11 @@ def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model,
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
x_batch: type_aliases.InputTensors,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[lr.BatchSize] = None,
clipping_loss: Optional[LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
clipping_loss: Optional[type_aliases.LossFn] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
"""Computes the per-example clipped loss gradient and other useful outputs.
@@ -287,7 +289,9 @@ def compute_clipped_gradients_and_outputs(
# c is computed based on the gradient of w*l, so that if we scale w*l by c,
# the result has bounded per-example gradients. So the loss to optimize is
# c*w*l. Here we compute c*w before passing it to the loss.
weight_batch = lr.maybe_add_microbatch_axis(weight_batch, num_microbatches)
weight_batch = common_manip_utils.maybe_add_microbatch_axis(
weight_batch, num_microbatches
)
if num_microbatches is None:
clip_weights = clip_weights * weight_batch # shape [num_microbatches]
else:
@@ -303,8 +307,12 @@ def compute_clipped_gradients_and_outputs(
# is not defined in the contract so may not hold, especially for
# custom losses.
y_pred = input_model(x_batch, training=True)
mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
mb_y_batch = common_manip_utils.maybe_add_microbatch_axis(
y_batch, num_microbatches
)
mb_y_pred = common_manip_utils.maybe_add_microbatch_axis(
y_pred, num_microbatches
)
clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
clipped_grads = tape.gradient(
clipping_loss_value,
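
For orientation (a hedged usage sketch, not part of this diff), the public entry
point above can be driven as follows with the default registry; the toy model and
batch are placeholders, and compute_clipped_gradients_and_outputs follows the same
calling pattern with an additional l2_norm_clip argument.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.random.normal([8, 4])
y_batch = tf.ones([8, 1])
norms = clip_grads.compute_gradient_norms(
    input_model=model,
    layer_registry=layer_registry.make_default_layer_registry(),
    x_batch=x_batch,
    y_batch=y_batch,
    num_microbatches=2,  # optional; must divide the batch size (8 here)
)
# `norms` has one entry per microbatch (shape [2]) or, with
# num_microbatches=None, one entry per example (shape [8]).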

View file

@@ -12,22 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
from typing import Any, Dict, Optional, Text, Tuple
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
# ==============================================================================
# Type aliases
# ==============================================================================
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
ModelGenerator = Callable[
[LayerGenerator, Union[int, List[int]], int], tf.keras.Model
]
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
# ==============================================================================
@@ -52,7 +44,7 @@ def double_dense_layer_computation(
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[int],
) -> layer_registry.RegistryFunctionOutput:
) -> type_aliases.RegistryFunctionOutput:
"""Layer registry function for the custom `DoubleDense` layer class."""
vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation(
layer_instance.dense1, input_args, input_kwargs, tape, num_microbatches
@@ -69,304 +61,17 @@ def double_dense_layer_computation(
return [vars1, vars2], outputs, sqr_norm_fn
def test_loss_fn(
x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None
) -> tf.Tensor:
# Define a loss function which is unlikely to be coincidentally defined.
if weights is None:
weights = 1.0
loss = 3.14 * tf.reduce_sum(
tf.cast(weights, tf.float32) * tf.square(x - y), axis=1
)
return loss
def compute_true_gradient_norms(
input_model: tf.keras.Model,
x_batch: tf.Tensor,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor],
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
trainable_vars: Optional[tf.Variable] = None,
) -> layer_registry.OutputTensor:
"""Computes the real gradient norms for an input `(model, x, y)`."""
if per_example_loss_fn is None:
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
with tf.GradientTape(persistent=True) as tape:
y_pred = input_model(x_batch)
loss = per_example_loss_fn(y_batch, y_pred, weight_batch)
if num_microbatches is not None:
loss = tf.reduce_mean(
tf.reshape(
loss,
tf.concat([[num_microbatches, -1], tf.shape(loss)[1:]], axis=0),
),
axis=1,
)
if isinstance(loss, tf.RaggedTensor):
loss = loss.to_tensor()
sqr_norms = []
trainable_vars = trainable_vars or input_model.trainable_variables
for var in trainable_vars:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, len(jacobian.shape))
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
sqr_norms.append(sqr_norm)
sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
def get_computed_and_true_norms(
model_generator: ModelGenerator,
layer_generator: LayerGenerator,
input_dims: Union[int, List[int]],
output_dim: int,
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
is_eager: bool,
x_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
partial: bool = False,
) -> Tuple[tf.Tensor, tf.Tensor]:
"""Obtains the true and computed gradient norms for a model and batch input.
Helpful testing wrapper function used to avoid code duplication.
Args:
model_generator: A function which takes in three arguments:
`layer_generator`, `idim`, and `odim`. Returns a `tf.keras.Model` that
accepts input tensors of dimension `idim` and returns output tensors of
dimension `odim`. Layers of the model are based on the `layer_generator`
(see below for its description).
layer_generator: A function which takes in two arguments: `idim` and `odim`.
Returns a `tf.keras.layers.Layer` that accepts input tensors of dimension
`idim` and returns output tensors of dimension `odim`.
input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
output_dim: The output dimension of the test `tf.keras.Model` instance.
per_example_loss_fn: If not None, used as vectorized per example loss
function.
num_microbatches: The number of microbatches. None or an integer.
is_eager: whether the model should be run eagerly.
x_batch: inputs to be tested.
weight_batch: optional weights passed to the loss.
rng_seed: used as a seed for random initialization.
registry: required for fast clipping.
partial: Whether to compute the gradient norm with respect to a partial set
of variables. If True, only consider the variables in the first layer.
Returns:
A `tuple` `(computed_norm, true_norms)`. The first element contains the
clipped gradient norms that are generated by
`clip_grads.compute_gradient_norms()` under the setting given by the given
model and layer generators. The second element contains the true clipped
gradient norms under the aforementioned setting.
"""
model = model_generator(layer_generator, input_dims, output_dim)
model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
loss=tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
),
run_eagerly=is_eager,
)
trainable_vars = None
if partial:
# Gets the first layer with variables.
for l in model.layers:
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
model,
x_batch,
y_batch,
weight_batch,
per_example_loss_fn,
num_microbatches,
trainable_vars=trainable_vars,
)
return (computed_norms, true_norms)
# ==============================================================================
# Model generators.
# ==============================================================================
def make_two_layer_sequential_model(layer_generator, input_dim, output_dim):
"""Creates a 2-layer sequential model."""
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(input_dim,)))
model.add(layer_generator(input_dim, output_dim))
model.add(tf.keras.layers.Dense(1))
return model
def make_three_layer_sequential_model(layer_generator, input_dim, output_dim):
"""Creates a 3-layer sequential model."""
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(input_dim,)))
layer1 = layer_generator(input_dim, output_dim)
model.add(layer1)
if isinstance(layer1, tf.keras.layers.Embedding):
# Having multiple consecutive embedding layers does not make sense since
# embedding layers only map integers to real-valued vectors.
model.add(tf.keras.layers.Dense(output_dim))
else:
model.add(layer_generator(output_dim, output_dim))
model.add(tf.keras.layers.Dense(1))
return model
def make_two_layer_functional_model(layer_generator, input_dim, output_dim):
"""Creates a 2-layer 1-input functional model with a pre-output square op."""
inputs = tf.keras.Input(shape=(input_dim,))
layer1 = layer_generator(input_dim, output_dim)
temp1 = layer1(inputs)
temp2 = tf.square(temp1)
outputs = tf.keras.layers.Dense(1)(temp2)
return tf.keras.Model(inputs=inputs, outputs=outputs)
def make_two_tower_model(layer_generator, input_dim, output_dim):
"""Creates a 2-layer 2-input functional model."""
inputs1 = tf.keras.Input(shape=(input_dim,))
layer1 = layer_generator(input_dim, output_dim)
temp1 = layer1(inputs1)
inputs2 = tf.keras.Input(shape=(input_dim,))
layer2 = layer_generator(input_dim, output_dim)
temp2 = layer2(inputs2)
temp3 = tf.add(temp1, temp2)
outputs = tf.keras.layers.Dense(1)(temp3)
return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs)
def make_bow_model(layer_generator, input_dims, output_dim):
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dims argument, which is the number of ids
# in each example.
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
return tf.keras.Model(inputs=inputs, outputs=example_embs)
def make_dense_bow_model(layer_generator, input_dims, output_dim):
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dims argument, which is the number of ids
# in each example.
cardinality = 10
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
outputs = tf.keras.layers.Dense(1)(example_embs)
return tf.keras.Model(inputs=inputs, outputs=outputs)
def make_weighted_bow_model(layer_generator, input_dims, output_dim):
# NOTE: This model only accepts dense input tensors.
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dims argument, which is the number of ids
# in each example.
cardinality = 10
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
feature_weights = tf.random.uniform(tf.shape(feature_embs))
weighted_embs = feature_embs * feature_weights
reduction_axes = tf.range(1, len(weighted_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
)
outputs = tf.keras.layers.Dense(1)(example_embs)
return tf.keras.Model(inputs=inputs, outputs=outputs)
# ==============================================================================
# Factory functions.
# ==============================================================================
def get_nd_test_batches(n: int):
"""Returns a list of input batches of dimension n."""
# The first two batches have a single element, the last batch has 2 elements.
x0 = tf.zeros([1, n], dtype=tf.float64)
x1 = tf.constant([range(n)], dtype=tf.float64)
x2 = tf.concat([x0, x1], axis=0)
w0 = tf.constant([1], dtype=tf.float64)
w1 = tf.constant([2], dtype=tf.float64)
w2 = tf.constant([0.5, 0.5], dtype=tf.float64)
return [x0, x1, x2], [w0, w1, w2]
def get_dense_layer_generators():
def sigmoid_dense_layer(b):
return tf.keras.layers.Dense(b, activation='sigmoid')
return {
'pure_dense': lambda a, b: tf.keras.layers.Dense(b),
'sigmoid_dense': lambda a, b: sigmoid_dense_layer(b),
}
def get_dense_model_generators():
return {
'seq1': make_two_layer_sequential_model,
'seq2': make_three_layer_sequential_model,
'func1': make_two_layer_functional_model,
'tower1': make_two_tower_model,
}
def get_embedding_model_generators():
return {
'bow1': make_bow_model,
'bow2': make_dense_bow_model,
'weighted_bow1': make_weighted_bow_model,
}
# ==============================================================================
# Main tests.
# ==============================================================================
class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase):
class DirectWeightsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
input_dim=[1, 2], clip_value=[1e-6, 0.5, 1.0, 2.0, 10.0, 1e6]
)
def test_clip_weights(self, input_dim, clip_value):
tol = 1e-6
ts, _ = get_nd_test_batches(input_dim)
ts, _ = common_test_utils.get_nd_test_batches(input_dim)
for t in ts:
weights = clip_grads.compute_clip_weights(clip_value, t)
self.assertAllLessEqual(t * weights, clip_value + tol)
@@ -375,136 +80,12 @@ class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsNone(clip_grads.compute_clip_weights(None, tf.ones(3)))
class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
model_name=list(get_dense_model_generators().keys()),
layer_name=list(get_dense_layer_generators().keys()),
input_dim=[4],
output_dim=[2],
per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 1, 2],
is_eager=[True, False],
partial=[True, False],
weighted=[True, False],
)
def test_gradient_norms_on_various_models(
self,
model_name,
layer_name,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
partial,
weighted,
):
model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name]
x_batches, weight_batches = get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry()
for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0]
if num_microbatches is not None and batch_size % num_microbatches != 0:
continue
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator,
layer_generator,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=default_registry,
partial=partial,
)
expected_size = num_microbatches or batch_size
self.assertEqual(computed_norms.shape[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
# supports them for embeddings.
@parameterized.product(
x_batch=[
# 2D inputs.
tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
dtype=tf.int32,
),
# 3D inputs.
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
tf.convert_to_tensor(
[[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
dtype=tf.int32,
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
dtype=tf.int32,
),
],
model_name=list(get_embedding_model_generators().keys()),
output_dim=[2],
per_example_loss_fn=[None, test_loss_fn],
num_microbatches=[None, 2],
is_eager=[True, False],
partial=[True, False],
)
def test_gradient_norms_on_various_models(
self,
x_batch,
model_name,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
partial,
):
batch_size = x_batch.shape[0]
# The following are invalid test combinations, and are skipped.
if (
num_microbatches is not None and batch_size % num_microbatches != 0
) or (
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
):
return
default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name]
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=model_generator,
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_batch=x_batch,
registry=default_registry,
partial=partial,
)
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
class CustomLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
input_dim=[3],
output_dim=[2],
per_example_loss_fn=[None, test_loss_fn],
per_example_loss_fn=[None, common_test_utils.test_loss_fn],
num_microbatches=[None, 2],
is_eager=[True, False],
partial=[True, False],
@@ -522,29 +103,31 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
):
registry = layer_registry.make_default_layer_registry()
registry.insert(DoubleDense, double_dense_layer_computation)
x_batches, weight_batches = get_nd_test_batches(input_dim)
x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim)
for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0]
if num_microbatches is not None and batch_size % num_microbatches != 0:
continue
(computed_norms, true_norms) = get_computed_and_true_norms(
model_generator=make_two_layer_sequential_model,
layer_generator=lambda a, b: DoubleDense(b),
input_dims=input_dim,
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_batch=x_batch,
weight_batch=weight_batch if weighted else None,
registry=registry,
partial=partial,
(computed_norms, true_norms) = (
common_test_utils.get_computed_and_true_norms(
model_generator=common_test_utils.make_two_layer_sequential_model,
layer_generator=lambda a, b: DoubleDense(b),
input_dims=input_dim,
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_batch=x_batch,
weight_batch=weight_batch if weighted else None,
registry=registry,
partial=partial,
)
)
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
class ClipGradsComputeClippedGradsAndOutputsTest(
class ComputeClippedGradsAndOutputsTest(
tf.test.TestCase, parameterized.TestCase
):
@@ -553,7 +136,7 @@ class ClipGradsComputeClippedGradsAndOutputsTest(
dense_generator = lambda a, b: tf.keras.layers.Dense(b)
self._input_dim = 2
self._output_dim = 3
self._model = make_two_layer_sequential_model(
self._model = common_test_utils.make_two_layer_sequential_model(
dense_generator, self._input_dim, self._output_dim
)

View file

@@ -0,0 +1,44 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of common utility functions for tensor/data manipulation."""
from typing import Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def maybe_add_microbatch_axis(
x: tf.Tensor,
num_microbatches: Optional[type_aliases.BatchSize],
) -> tf.Tensor:
"""Adds the microbatch axis.
Args:
x: the input tensor.
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
the batch size.
Returns:
The input tensor x, reshaped from [batch_size, ...] to
[num_microbatches, batch_size / num_microbatches, ...].
"""
if num_microbatches is None:
return x
with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
):
return tf.reshape(
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
)
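
A quick shape check of the helper above (a hypothetical snippet, not part of the
commit), run eagerly:

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils

x = tf.reshape(tf.range(12.0), [4, 3])  # [batch_size=4, 3]
# With two microbatches, the batch axis splits into [2, 2, ...].
print(common_manip_utils.maybe_add_microbatch_axis(x, 2).shape)     # (2, 2, 3)
# With num_microbatches=None, the tensor is returned unchanged.
print(common_manip_utils.maybe_add_microbatch_axis(x, None).shape)  # (4, 3)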

View file

@@ -0,0 +1,284 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of common utility functions for unit testing."""
from typing import Callable, List, Optional, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
# ==============================================================================
# Helper functions
# ==============================================================================
def get_nd_test_batches(n: int):
"""Returns a list of input batches of dimension n."""
# The first two batches have a single element, the last batch has 2 elements.
x0 = tf.zeros([1, n], dtype=tf.float64)
x1 = tf.constant([range(n)], dtype=tf.float64)
x2 = tf.concat([x0, x1], axis=0)
w0 = tf.constant([1], dtype=tf.float64)
w1 = tf.constant([2], dtype=tf.float64)
w2 = tf.constant([0.5, 0.5], dtype=tf.float64)
return [x0, x1, x2], [w0, w1, w2]
def test_loss_fn(
x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None
) -> tf.Tensor:
# Define a loss function which is unlikely to be coincidentally defined.
if weights is None:
weights = 1.0
loss = 3.14 * tf.reduce_sum(
tf.cast(weights, tf.float32) * tf.square(x - y), axis=1
)
return loss
def compute_true_gradient_norms(
input_model: tf.keras.Model,
x_batch: tf.Tensor,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor],
per_example_loss_fn: Optional[type_aliases.LossFn],
num_microbatches: Optional[int],
trainable_vars: Optional[tf.Variable] = None,
) -> type_aliases.OutputTensors:
"""Computes the real gradient norms for an input `(model, x, y)`."""
if per_example_loss_fn is None:
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
with tf.GradientTape(persistent=True) as tape:
y_pred = input_model(x_batch)
loss = per_example_loss_fn(y_batch, y_pred, weight_batch)
if num_microbatches is not None:
loss = tf.reduce_mean(
tf.reshape(
loss,
tf.concat([[num_microbatches, -1], tf.shape(loss)[1:]], axis=0),
),
axis=1,
)
if isinstance(loss, tf.RaggedTensor):
loss = loss.to_tensor()
sqr_norms = []
trainable_vars = trainable_vars or input_model.trainable_variables
for var in trainable_vars:
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
reduction_axes = tf.range(1, len(jacobian.shape))
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
sqr_norms.append(sqr_norm)
sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
def get_computed_and_true_norms(
model_generator: type_aliases.ModelGenerator,
layer_generator: type_aliases.LayerGenerator,
input_dims: Union[int, List[int]],
output_dim: int,
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
is_eager: bool,
x_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
partial: bool = False,
) -> Tuple[tf.Tensor, tf.Tensor]:
"""Obtains the true and computed gradient norms for a model and batch input.
Helpful testing wrapper function used to avoid code duplication.
Args:
model_generator: A function which takes in three arguments:
`layer_generator`, `idim`, and `odim`. Returns a `tf.keras.Model` that
accepts input tensors of dimension `idim` and returns output tensors of
dimension `odim`. Layers of the model are based on the `layer_generator`
(see below for its description).
layer_generator: A function which takes in two arguments: `idim` and `odim`.
Returns a `tf.keras.layers.Layer` that accepts input tensors of dimension
`idim` and returns output tensors of dimension `odim`.
input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
output_dim: The output dimension of the test `tf.keras.Model` instance.
per_example_loss_fn: If not None, used as vectorized per example loss
function.
num_microbatches: The number of microbatches. None or an integer.
is_eager: whether the model should be run eagerly.
x_batch: inputs to be tested.
weight_batch: optional weights passed to the loss.
rng_seed: used as a seed for random initialization.
registry: required for fast clipping.
partial: Whether to compute the gradient norm with respect to a partial set
of variables. If True, only consider the variables in the first layer.
Returns:
A `tuple` `(computed_norm, true_norms)`. The first element contains the
clipped gradient norms that are generated by
`clip_grads.compute_gradient_norms()` under the setting given by the given
model and layer generators. The second element contains the true clipped
gradient norms under the aforementioned setting.
"""
model = model_generator(layer_generator, input_dims, output_dim)
model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
loss=tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
),
run_eagerly=is_eager,
)
trainable_vars = None
if partial:
# Gets the first layer with variables.
for l in model.layers:
trainable_vars = l.trainable_variables
if trainable_vars:
break
y_pred = model(x_batch)
y_batch = tf.ones_like(y_pred)
tf.keras.utils.set_random_seed(rng_seed)
computed_norms = clip_grads.compute_gradient_norms(
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
weight_batch=weight_batch,
layer_registry=registry,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
trainable_vars=trainable_vars,
)
tf.keras.utils.set_random_seed(rng_seed)
true_norms = compute_true_gradient_norms(
model,
x_batch,
y_batch,
weight_batch,
per_example_loss_fn,
num_microbatches,
trainable_vars=trainable_vars,
)
return (computed_norms, true_norms)
# ==============================================================================
# Model generators.
# ==============================================================================
def make_two_layer_sequential_model(layer_generator, input_dim, output_dim):
"""Creates a 2-layer sequential model."""
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(input_dim,)))
model.add(layer_generator(input_dim, output_dim))
model.add(tf.keras.layers.Dense(1))
return model
def make_three_layer_sequential_model(layer_generator, input_dim, output_dim):
"""Creates a 3-layer sequential model."""
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(input_dim,)))
layer1 = layer_generator(input_dim, output_dim)
model.add(layer1)
if isinstance(layer1, tf.keras.layers.Embedding):
# Having multiple consecutive embedding layers does not make sense since
# embedding layers only map integers to real-valued vectors.
model.add(tf.keras.layers.Dense(output_dim))
else:
model.add(layer_generator(output_dim, output_dim))
model.add(tf.keras.layers.Dense(1))
return model
def make_two_layer_functional_model(layer_generator, input_dim, output_dim):
"""Creates a 2-layer 1-input functional model with a pre-output square op."""
inputs = tf.keras.Input(shape=(input_dim,))
layer1 = layer_generator(input_dim, output_dim)
temp1 = layer1(inputs)
temp2 = tf.square(temp1)
outputs = tf.keras.layers.Dense(1)(temp2)
return tf.keras.Model(inputs=inputs, outputs=outputs)
def make_two_tower_model(layer_generator, input_dim, output_dim):
"""Creates a 2-layer 2-input functional model."""
inputs1 = tf.keras.Input(shape=(input_dim,))
layer1 = layer_generator(input_dim, output_dim)
temp1 = layer1(inputs1)
inputs2 = tf.keras.Input(shape=(input_dim,))
layer2 = layer_generator(input_dim, output_dim)
temp2 = layer2(inputs2)
temp3 = tf.add(temp1, temp2)
outputs = tf.keras.layers.Dense(1)(temp3)
return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs)
def make_bow_model(layer_generator, input_dims, output_dim):
"""Creates a simple embedding bow model."""
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dims argument, which is the number of ids
# in each example.
emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
return tf.keras.Model(inputs=inputs, outputs=example_embs)
def make_dense_bow_model(layer_generator, input_dims, output_dim):
"""Creates an embedding bow model with a `Dense` layer."""
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dims argument, which is the number of ids
# in each example.
cardinality = 10
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
reduction_axes = tf.range(1, len(feature_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
)
outputs = tf.keras.layers.Dense(1)(example_embs)
return tf.keras.Model(inputs=inputs, outputs=outputs)
def make_weighted_bow_model(layer_generator, input_dims, output_dim):
"""Creates a weighted embedding bow model."""
# NOTE: This model only accepts dense input tensors.
del layer_generator
inputs = tf.keras.Input(shape=input_dims)
# For the Embedding layer, input_dim is the vocabulary size. This should
# be distinguished from the input_dims argument, which is the number of ids
# in each example.
cardinality = 10
emb_layer = tf.keras.layers.Embedding(
input_dim=cardinality, output_dim=output_dim
)
feature_embs = emb_layer(inputs)
feature_weights = tf.random.uniform(tf.shape(feature_embs))
weighted_embs = feature_embs * feature_weights
reduction_axes = tf.range(1, len(weighted_embs.shape))
example_embs = tf.expand_dims(
tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
)
outputs = tf.keras.layers.Dense(1)(example_embs)
return tf.keras.Model(inputs=inputs, outputs=outputs)

View file

@@ -13,16 +13,12 @@
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Text, Tuple, Union
from typing import Any, List, Optional, Set, Tuple
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def has_internal_compute_graph(input_object: Any):
@@ -38,9 +34,9 @@ def has_internal_compute_graph(input_object: Any):
def model_forward_pass(
input_model: tf.keras.Model,
inputs: PackedTensors,
generator_fn: GeneratorFunction = None,
) -> Tuple[PackedTensors, List[Any]]:
inputs: type_aliases.PackedTensors,
generator_fn: type_aliases.GeneratorFunction = None,
) -> Tuple[type_aliases.PackedTensors, List[Any]]:
"""Does a forward pass of a model and returns useful intermediates.
NOTE: the graph traversal algorithm is an adaptation of the logic in the
@@ -192,7 +188,7 @@ def add_aggregate_noise(
def generate_model_outputs_using_core_keras_layers(
input_model: tf.keras.Model,
custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic
) -> PackedTensors:
) -> type_aliases.PackedTensors:
"""Returns the model outputs generated by only core Keras layers.
Args:

View file

@@ -49,27 +49,12 @@ aggregation at the microbatch level.
# The detailed algorithm can be found in go/fast-dpsgd-mb.
# copybara.strip_end
from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Type, Union
from typing import Type
import tensorflow as tf
# ==============================================================================
# Type aliases
# ==============================================================================
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]]
BatchSize = Union[int, tf.Tensor]
SquareNormFunction = Callable[[OutputTensor], tf.Tensor]
RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction]
RegistryFunction = Callable[
[Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
RegistryFunctionOutput,
]
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions.dense import dense_layer_computation
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions.embedding import embedding_layer_computation
# ==============================================================================
@@ -87,14 +72,16 @@ class LayerRegistry:
"""Checks if a layer instance's class is in the registry."""
return hash(layer_instance.__class__) in self._registry
def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction:
def lookup(
self, layer_instance: tf.keras.layers.Layer
) -> type_aliases.RegistryFunction:
"""Returns the layer registry function for a given layer instance."""
return self._registry[hash(layer_instance.__class__)]
def insert(
self,
layer_class: Type[tf.keras.layers.Layer],
layer_registry_function: RegistryFunction,
layer_registry_function: type_aliases.RegistryFunction,
):
"""Inserts a layer registry function into the internal dictionaries."""
layer_key = hash(layer_class)
@@ -102,222 +89,6 @@ class LayerRegistry:
self._registry[layer_key] = layer_registry_function
# ==============================================================================
# Utilities
# ==============================================================================
def maybe_add_microbatch_axis(
x: tf.Tensor,
num_microbatches: Optional[BatchSize],
) -> tf.Tensor:
"""Adds the microbatch axis.
Args:
x: the input tensor.
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
the batch size.
Returns:
The input tensor x, reshaped from [batch_size, ...] to
[num_microbatches, batch_size / num_microbatches, ...].
"""
if num_microbatches is None:
return x
with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
):
return tf.reshape(
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
)
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`.
The logic for this computation is based on the following paper:
https://arxiv.org/abs/1510.01799
For the sake of efficiency, we fuse the variables and square grad norms
for the kernel weights and bias vector together.
Args:
layer_instance: A `tf.keras.layers.Dense` instance.
input_args: A `tuple` containing the first part of `layer_instance` input.
Specifically, `layer_instance(*inputs_args, **input_kwargs)` should return
a valid output.
input_kwargs: A `tuple` containing the second part of `layer_instance`
input. Specifically, `layer_instance(*inputs_args, **input_kwargs)` should
return a valid output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
indicating whether and how the losses are grouped into microbatches. If
not None, num_microbatches must divide the batch size.
Returns:
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick,
`outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
a function that takes one input, a `tf.Tensor` that represents the output
of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
`tf.GradientTape` instance that records the dense layer computation and
`summed_loss` is the sum of the per-example losses of the underlying model.
This function then returns the per-example squared L2 gradient norms of the
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
"""
if input_kwargs:
raise ValueError("Dense layer calls should not receive kwargs.")
del input_kwargs # Unused in dense layer calls.
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
orig_activation = layer_instance.activation
layer_instance.activation = None
base_vars = layer_instance(*input_args)
tape.watch(base_vars)
layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars
def sqr_norm_fn(base_vars_grads):
def _compute_gramian(x):
if num_microbatches is not None:
x_microbatched = maybe_add_microbatch_axis(x, num_microbatches)
return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
else:
# Special handling for better efficiency
return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
inputs_gram = _compute_gramian(*input_args)
base_vars_grads_gram = _compute_gramian(base_vars_grads)
if layer_instance.use_bias:
# Adding a bias term is equivalent to a layer with no bias term and which
# adds an additional variable to the layer input that only takes a
# constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
# of the squared values of the inputs.
inputs_gram += 1.0
return tf.reduce_sum(
inputs_gram * base_vars_grads_gram,
axis=tf.range(1, tf.rank(inputs_gram)),
)
return base_vars, outputs, sqr_norm_fn
def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`.
The logic of this computation is based on the `tf.keras.layers.Dense`
computation and the fact that an embedding layer is just a dense layer
with no activation function and an output vector of the form X*W for input
X, where the i-th row of W is the i-th embedding vector and the j-th row of
X is a one-hot vector representing the input of example j.
Args:
layer_instance: A `tf.keras.layers.Embedding` instance.
input_args: See `dense_layer_computation()`.
input_kwargs: See `dense_layer_computation()`.
tape: See `dense_layer_computation()`.
num_microbatches: See `dense_layer_computation()`.
Returns:
See `dense_layer_computation()`.
"""
if input_kwargs:
raise ValueError("Embedding layer calls should not receive kwargs.")
del input_kwargs # Unused in embedding layer calls.
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
if hasattr(layer_instance, "sparse"): # for backwards compatibility
if layer_instance.sparse:
raise NotImplementedError("Sparse output tensors are not supported.")
if isinstance(input_args[0], tf.SparseTensor):
raise NotImplementedError("Sparse input tensors are not supported.")
# Disable experimental features.
if hasattr(layer_instance, "_use_one_hot_matmul"):
if layer_instance._use_one_hot_matmul: # pylint: disable=protected-access
raise NotImplementedError(
"The experimental embedding feature"
"'_use_one_hot_matmul' is not supported."
)
input_ids = tf.cast(*input_args, tf.int32)
base_vars = layer_instance.trainable_variables[0]
tape.watch(base_vars)
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
def sqr_norm_fn(base_vars_grads):
# Get a 1D tensor of the row indices.
nrows = tf.shape(input_ids)[0]
if isinstance(input_ids, tf.RaggedTensor):
row_indices = tf.expand_dims(
input_ids.merge_dims(1, -1).value_rowids(), axis=-1
)
elif isinstance(input_ids, tf.Tensor):
ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
repeats = tf.repeat(ncols, nrows)
row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
else:
raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
)
row_indices = tf.cast(row_indices, tf.int32)
if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int32)
nrows = num_microbatches
row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int32
)
# Sum-reduce the `IndexedSlices` that is the result of a `tape.gradient()`
# call. The sum is reduced by the repeated embedding indices and batch
# index. It is adapted from the logic in:
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
if not isinstance(base_vars_grads, tf.IndexedSlices):
raise NotImplementedError(
"Cannot parse embedding gradients of type: %s"
% base_vars_grads.__class__.__name__
)
slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
paired_indices = tf.concat(
[tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
axis=1,
)
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
x=paired_indices, axis=[0]
)
unique_batch_ids = unique_paired_indices[:, 0]
summed_gradients = tf.math.unsorted_segment_sum(
base_vars_grads.values,
new_index_positions,
tf.shape(unique_paired_indices)[0],
)
# Compute the squared gradient norms at the per-example level.
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
return tf.sparse.segment_sum(
sqr_gradient_sum,
summed_data_range,
tf.sort(unique_batch_ids),
num_segments=nrows,
) # fill in empty inputs
return base_vars, outputs, sqr_norm_fn
# ==============================================================================
# Main factory methods
# ==============================================================================
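
For reference, a small hedged sketch (not part of this commit) of the registry
surface shown above; `Dense` is pre-registered by the default registry, and
custom layers follow the `DoubleDense` pattern from the test diff.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

registry = layer_registry.make_default_layer_registry()

# `lookup` keys off the layer instance's class and returns its registry
# function (here, the Dense computation now living in registry_functions/).
dense_fn = registry.lookup(tf.keras.layers.Dense(3))
print(dense_fn.__name__)  # dense_layer_computation

# Custom layers are registered explicitly, e.g. as in the test above:
#   registry.insert(DoubleDense, double_dense_layer_computation)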

View file

@@ -0,0 +1,50 @@
load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(
default_visibility = ["//visibility:public"],
)
py_library(
name = "dense",
srcs = ["dense.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
],
)
py_test(
name = "dense_test",
size = "large",
srcs = ["dense_test.py"],
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)
py_library(
name = "embedding",
srcs = ["embedding.py"],
srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
)
py_test(
name = "embedding_test",
size = "large",
srcs = ["embedding_test.py"],
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)

View file

@@ -0,0 +1,100 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.Dense`."""
from typing import Any, Dict, Optional, Text, Tuple
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`.
The logic for this computation is based on the following paper:
https://arxiv.org/abs/1510.01799
For the sake of efficiency, we fuse the variables and square grad norms
for the kernel weights and bias vector together.
Args:
layer_instance: A `tf.keras.layers.Dense` instance.
input_args: A `tuple` containing the first part of `layer_instance` input.
Specifically, `layer_instance(*inputs_args, **input_kwargs)` should return
a valid output.
input_kwargs: A `tuple` containing the second part of `layer_instance`
input. Specifically, `layer_instance(*inputs_args, **input_kwargs)` should
return a valid output.
tape: A `tf.GradientTape` instance that will be used to watch the output
`base_vars`.
num_microbatches: An optional numeric value or scalar `tf.Tensor` for
indicating whether and how the losses are grouped into microbatches. If
not None, num_microbatches must divide the batch size.
Returns:
A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the
intermediate Tensor used in the chain-rule / "fast" clipping trick,
`outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is
a function that takes one input, a `tf.Tensor` that represents the output
of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a
`tf.GradientTape` instance that records the dense layer computation and
`summed_loss` is the sum of the per-example losses of the underlying model.
This function then returns the per-example squared L2 gradient norms of the
trainable variables in `layer_instance`. These squared norms should be a 1D
`tf.Tensor` of length `batch_size`.
"""
if input_kwargs:
raise ValueError("Dense layer calls should not receive kwargs.")
del input_kwargs # Unused in dense layer calls.
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
orig_activation = layer_instance.activation
layer_instance.activation = None
base_vars = layer_instance(*input_args)
tape.watch(base_vars)
layer_instance.activation = orig_activation
outputs = orig_activation(base_vars) if orig_activation else base_vars
def sqr_norm_fn(base_vars_grads):
def _compute_gramian(x):
if num_microbatches is not None:
x_microbatched = common_manip_utils.maybe_add_microbatch_axis(
x,
num_microbatches,
)
return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
else:
# Special handling for better efficiency
return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))
inputs_gram = _compute_gramian(*input_args)
base_vars_grads_gram = _compute_gramian(base_vars_grads)
if layer_instance.use_bias:
# Adding a bias term is equivalent to a layer with no bias term and which
# adds an additional variable to the layer input that only takes a
# constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
# of the squared values of the inputs.
inputs_gram += 1.0
return tf.reduce_sum(
inputs_gram * base_vars_grads_gram,
axis=tf.range(1, tf.rank(inputs_gram)),
)
return base_vars, outputs, sqr_norm_fn
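
To make the returned triple concrete, here is a hedged sketch (not from this
commit) of how `base_vars`, `outputs`, and `sqr_norm_fn` fit together for a toy
layer and batch; the layer, data, and loss are placeholders.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense

layer = tf.keras.layers.Dense(3, activation="relu")
x = tf.random.normal([5, 2])  # 5 examples, 2 features
with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = dense.dense_layer_computation(
      layer, input_args=(x,), input_kwargs={}, tape=tape
  )
  # Any per-example loss built from `outputs` works; sum it over the batch.
  summed_loss = tf.reduce_sum(tf.square(outputs))
base_vars_grads = tape.gradient(summed_loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(base_vars_grads)  # shape [5]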

View file

@@ -0,0 +1,100 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
# ==============================================================================
# Helper functions.
# ==============================================================================
def get_dense_layer_generators():
def sigmoid_dense_layer(b):
return tf.keras.layers.Dense(b, activation='sigmoid')
return {
'pure_dense': lambda a, b: tf.keras.layers.Dense(b),
'sigmoid_dense': lambda a, b: sigmoid_dense_layer(b),
}
def get_dense_model_generators():
return {
'seq1': common_test_utils.make_two_layer_sequential_model,
'seq2': common_test_utils.make_three_layer_sequential_model,
'func1': common_test_utils.make_two_layer_functional_model,
'tower1': common_test_utils.make_two_tower_model,
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
model_name=list(get_dense_model_generators().keys()),
layer_name=list(get_dense_layer_generators().keys()),
input_dim=[4],
output_dim=[2],
per_example_loss_fn=[None, common_test_utils.test_loss_fn],
num_microbatches=[None, 1, 2],
is_eager=[True, False],
partial=[True, False],
weighted=[True, False],
)
def test_gradient_norms_on_various_models(
self,
model_name,
layer_name,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
partial,
weighted,
):
model_generator = get_dense_model_generators()[model_name]
layer_generator = get_dense_layer_generators()[layer_name]
x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim)
default_registry = layer_registry.make_default_layer_registry()
for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0]
if num_microbatches is not None and batch_size % num_microbatches != 0:
continue
computed_norms, true_norms = (
common_test_utils.get_computed_and_true_norms(
model_generator,
layer_generator,
input_dim,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch,
weight_batch=weight_batch if weighted else None,
registry=default_registry,
partial=partial,
)
)
expected_size = num_microbatches or batch_size
self.assertEqual(computed_norms.shape[0], expected_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
if __name__ == '__main__':
tf.test.main()

View file

@@ -0,0 +1,124 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.Embedding`."""
from typing import Any, Dict, Optional, Text, Tuple
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases

def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`.
The logic of this computation is based on the `tf.keras.layers.Dense`
computation and the fact that an embedding layer is just a dense layer
  with no activation function and an output of the form X*W for input
X, where the i-th row of W is the i-th embedding vector and the j-th row of
X is a one-hot vector representing the input of example j.
Args:
layer_instance: A `tf.keras.layers.Embedding` instance.
input_args: See `dense_layer_computation()` in `dense.py`.
input_kwargs: See `dense_layer_computation()` in `dense.py`.
tape: See `dense_layer_computation()` in `dense.py`.
num_microbatches: See `dense_layer_computation()` in `dense.py`.
Returns:
See `dense_layer_computation()` in `dense.py`.
"""
if input_kwargs:
raise ValueError("Embedding layer calls should not receive kwargs.")
del input_kwargs # Unused in embedding layer calls.
if len(input_args) != 1:
raise ValueError("Only layer inputs of length 1 are permitted.")
if hasattr(layer_instance, "sparse"): # for backwards compatibility
if layer_instance.sparse:
raise NotImplementedError("Sparse output tensors are not supported.")
if isinstance(input_args[0], tf.SparseTensor):
raise NotImplementedError("Sparse input tensors are not supported.")
# Disable experimental features.
if hasattr(layer_instance, "_use_one_hot_matmul"):
if layer_instance._use_one_hot_matmul: # pylint: disable=protected-access
raise NotImplementedError(
"The experimental embedding feature"
"'_use_one_hot_matmul' is not supported."
)
input_ids = tf.cast(*input_args, tf.int32)
base_vars = layer_instance.trainable_variables[0]
tape.watch(base_vars)
outputs = tf.nn.embedding_lookup(base_vars, input_ids)
def sqr_norm_fn(base_vars_grads):
# Get a 1D tensor of the row indices.
nrows = tf.shape(input_ids)[0]
if isinstance(input_ids, tf.RaggedTensor):
row_indices = tf.expand_dims(
input_ids.merge_dims(1, -1).value_rowids(), axis=-1
)
elif isinstance(input_ids, tf.Tensor):
ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
repeats = tf.repeat(ncols, nrows)
row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
else:
raise NotImplementedError(
"Cannot parse input_ids of type %s" % input_ids.__class__.__name__
)
row_indices = tf.cast(row_indices, tf.int32)
if num_microbatches is not None:
microbatch_size = tf.cast(nrows / num_microbatches, tf.int32)
nrows = num_microbatches
row_indices = tf.cast(
tf.math.floordiv(row_indices, microbatch_size), tf.int32
)
# Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()`
# call. The sum is reduced by the repeated embedding indices and batch
# index. It is adapted from the logic in:
# tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
if not isinstance(base_vars_grads, tf.IndexedSlices):
raise NotImplementedError(
"Cannot parse embedding gradients of type: %s"
% base_vars_grads.__class__.__name__
)
slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
paired_indices = tf.concat(
[tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
axis=1,
)
(unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
x=paired_indices, axis=[0]
)
unique_batch_ids = unique_paired_indices[:, 0]
summed_gradients = tf.math.unsorted_segment_sum(
base_vars_grads.values,
new_index_positions,
tf.shape(unique_paired_indices)[0],
)
# Compute the squared gradient norms at the per-example level.
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
return tf.sparse.segment_sum(
sqr_gradient_sum,
summed_data_range,
tf.sort(unique_batch_ids),
num_segments=nrows,
) # fill in empty inputs
return base_vars, outputs, sqr_norm_fn
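
Editorial note (not part of this change): a hedged usage sketch of the registry function above, going through `layer_registry.make_default_layer_registry()` and `clip_grads.compute_gradient_norms()`. It assumes the default registry maps `tf.keras.layers.Embedding` (and `Dense`) to their registry functions and that the model is compiled with a standard Keras loss; see `clip_grads.py` for the authoritative contract.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

inputs = tf.keras.Input(shape=(3,), dtype=tf.int32)
hidden = tf.keras.layers.Embedding(input_dim=10, output_dim=4)(inputs)
hidden = tf.keras.layers.GlobalAveragePooling1D()(hidden)
outputs = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.constant([[1, 2, 2], [0, 3, 3]], dtype=tf.int32)
y_batch = tf.constant([[1.0], [0.0]])

norms = clip_grads.compute_gradient_norms(
    model,
    layer_registry.make_default_layer_registry(),
    x_batch,
    y_batch,
)
# `norms` holds one per-example gradient norm; shape [2] for this batch.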

View file

@ -0,0 +1,111 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# ==============================================================================
# Helper functions.
# ==============================================================================
def get_embedding_model_generators():
return {
'bow1': common_test_utils.make_bow_model,
'bow2': common_test_utils.make_dense_bow_model,
'weighted_bow1': common_test_utils.make_weighted_bow_model,
  }

# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
# supports them for embeddings.
@parameterized.product(
x_batch=[
# 2D inputs.
tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
),
tf.ragged.constant(
[[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]],
dtype=tf.int32,
),
# 3D inputs.
tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
tf.convert_to_tensor(
[[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
dtype=tf.int32,
),
tf.ragged.constant(
[[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]],
dtype=tf.int32,
),
],
model_name=list(get_embedding_model_generators().keys()),
output_dim=[2],
per_example_loss_fn=[None, common_test_utils.test_loss_fn],
num_microbatches=[None, 2],
is_eager=[True, False],
partial=[True, False],
)
def test_gradient_norms_on_various_models(
self,
x_batch,
model_name,
output_dim,
per_example_loss_fn,
num_microbatches,
is_eager,
partial,
):
batch_size = x_batch.shape[0]
    # The following are invalid test combinations and are skipped.
if (
num_microbatches is not None and batch_size % num_microbatches != 0
) or (
model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor)
):
return
default_registry = layer_registry.make_default_layer_registry()
model_generator = get_embedding_model_generators()[model_name]
computed_norms, true_norms = (
common_test_utils.get_computed_and_true_norms(
model_generator=model_generator,
layer_generator=None,
input_dims=x_batch.shape[1:],
output_dim=output_dim,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
is_eager=is_eager,
x_batch=x_batch,
registry=default_registry,
partial=partial,
)
)
self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)

if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,49 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of type aliases used throughout the clipping library."""
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
import tensorflow as tf

# Tensorflow aliases.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
InputTensors = PackedTensors
OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
BatchSize = Union[int, tf.Tensor]
LossFn = Callable[..., tf.Tensor]

# Layer Registry aliases.
SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
RegistryFunction = Callable[
[Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
RegistryFunctionOutput,
]

# Clipping aliases.
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]

# Testing aliases.
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
ModelGenerator = Callable[
[LayerGenerator, Union[int, List[int]], int], tf.keras.Model
]
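
Editorial note (not part of this change): a hypothetical illustration of how a registry function lines up with the aliases above. The "scale" layer here is invented for the example: it has a single trainable scalar `w` and computes `w * x`, so d(loss_i)/dw = sum_j g_ij * x_ij with g = d(loss)/d(outputs), and the squared per-example gradient norm is that sum, squared.

from typing import Any, Dict, Text, Tuple

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def scale_layer_computation(
    layer_instance: Any,
    input_args: Tuple[Any, ...],
    input_kwargs: Dict[Text, Any],
    tape: tf.GradientTape,
) -> type_aliases.RegistryFunctionOutput:
  """Registry-style computation for the hypothetical scale layer."""
  del input_kwargs  # Unused in this sketch.
  x = input_args[0]
  base_vars = layer_instance.w * x  # Recompute the layer output under the tape.
  tape.watch(base_vars)
  outputs = base_vars

  def sqr_norm_fn(base_vars_grads):
    # One squared norm per example: (sum_j g_ij * x_ij)^2.
    return tf.square(tf.reduce_sum(base_vars_grads * x, axis=1))

  return base_vars, outputs, sqr_norm_fn

Such a function could then be attached to a `LayerRegistry` instance for the corresponding layer class, in the same way the default registry wires up `Dense` and `Embedding`.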

View file

@ -17,6 +17,7 @@ py_library(
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],

View file

@ -16,8 +16,8 @@
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
_PRIVATIZED_LOSS_NAME = 'privatized_loss'
@ -143,7 +143,8 @@ def make_dp_model_class(cls):
if (
layer_registry is not None
and gradient_clipping_utils.all_trainable_layers_are_registered(
self, layer_registry
self,
layer_registry,
)
and gradient_clipping_utils.has_internal_compute_graph(self)
):
@ -273,10 +274,16 @@ def make_dp_model_class(cls):
grads = clipped_grads
else:
logging.info('Computing gradients using original clipping algorithm.')
# Computes per-example clipped gradients directly. This is called
# if at least one of the layers cannot use the "fast" gradient clipping
# algorithm.
reshape_fn = lambda z: lr.maybe_add_microbatch_axis(z, num_microbatches)
def reshape_fn(z):
return common_manip_utils.maybe_add_microbatch_axis(
z,
num_microbatches,
)
microbatched_data = tf.nest.map_structure(reshape_fn, data)
clipped_grads = tf.vectorized_map(
self._compute_per_example_grads,