forked from 626_privacy/tensorflow_privacy
Add support for microbatching in the tf.keras.layers.LayerNormalization
fast square norm function.
PiperOrigin-RevId: 565050132
parent bcc0d4927e
commit c7db4fa8cb
3 changed files with 19 additions and 12 deletions
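
In the fast gradient clipping code, each registered layer supplies a function that maps per-example gradients to per-example squared norms; with microbatching, the gradients are first summed within each microbatch, so one norm is produced per microbatch instead of per example. A toy sketch of that difference (illustrative only, not code from this repository):

import tensorflow as tf

grads = tf.reshape(tf.range(12, dtype=tf.float32), [4, 3])  # 4 "per-example" gradients

# One squared norm per example.
per_example = tf.reduce_sum(tf.square(grads), axis=1)        # shape [4]

# With 2 microbatches: sum the gradients inside each microbatch first,
# then take one squared norm per microbatch.
num_microbatches = 2
grouped = tf.reshape(grads, [num_microbatches, -1, 3])       # shape [2, 2, 3]
summed = tf.reduce_sum(grouped, axis=1)                      # shape [2, 3]
per_microbatch = tf.reduce_sum(tf.square(summed), axis=1)    # shape [2]

print(per_example.numpy())     # 4 values
print(per_microbatch.numpy())  # 2 values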
BUILD

@@ -102,7 +102,10 @@ py_library(
     name = "layer_normalization",
     srcs = ["layer_normalization.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+    deps = [
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
 )

 py_test(
layer_normalization.py

@@ -15,18 +15,13 @@
 from typing import Any, Mapping, Tuple, Union

 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


 # ==============================================================================
 # Supported Keras layers
 # ==============================================================================
-def _sqr_norm_fn(grads):
-  stacked_grads = tf.stack(grads, axis=-1)
-  reduction_axes = tf.range(1, tf.rank(stacked_grads))
-  return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
-
-
 def layer_normalization_computation(
     layer_instance: tf.keras.layers.LayerNormalization,
     input_args: Tuple[Any, ...],

@@ -51,9 +46,6 @@ def layer_normalization_computation(
     See `dense_layer_computation()` in `dense.py`.
   """
   del input_kwargs  # Unused in layer normaliztion calls.
-  if num_microbatches is not None:
-    raise NotImplementedError("Microbatching is not currently supported.")
-
   # To make sure the watched variables (beta, gamma) generate per-example
   # gradients, we need to convert trainable variables from shape [S] to
   # [batch_size, S] via duplication to `tf.shape(inputs)` via broadcasting.

@@ -86,4 +78,14 @@ def layer_normalization_computation(
   layer_instance.gamma = orig_gamma
   layer_instance.beta = orig_beta

-  return base_vars, outputs, _sqr_norm_fn
+  def sqr_norm_fn(grads):
+    stacked_grads = tf.stack(grads, axis=-1)
+    if num_microbatches is not None:
+      stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
+          grads, num_microbatches
+      )
+      stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
+    reduction_axes = tf.range(1, tf.rank(stacked_grads))
+    return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
+
+  return base_vars, outputs, sqr_norm_fn
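
The new dependency, common_manip_utils.maybe_add_microbatch_axis, is not shown in this diff. Based on how it is used above, it presumably leaves a tensor unchanged when num_microbatches is None and otherwise folds the leading batch axis into [num_microbatches, batch_size // num_microbatches, ...]. A hypothetical stand-in for intuition only (the real helper may differ):

import tensorflow as tf

def maybe_add_microbatch_axis(x, num_microbatches):
  """Hypothetical stand-in for the library helper, for intuition only."""
  if num_microbatches is None:
    return x
  x = tf.convert_to_tensor(x)
  # Fold the batch axis into [num_microbatches, batch_size // num_microbatches, ...].
  new_shape = tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
  return tf.reshape(x, new_shape)

g = tf.ones([8, 5])
print(maybe_add_microbatch_axis(g, None).shape)  # (8, 5)
print(maybe_add_microbatch_axis(g, 4).shape)     # (4, 2, 5)

After this reshape, the reduce_sum over axis 1 in sqr_norm_fn collapses each microbatch's gradients into a single vector before the squared norm is taken.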
layer_normalization_test.py

@@ -87,6 +87,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       layer_name=list(get_layer_norm_layer_generators().keys()),
       parameter_tuple=get_layer_norm_parameter_tuples(),
       layer_registry_name=list(get_layer_norm_registries().keys()),
+      num_microbatches=[None, 2],
       is_eager=[True, False],
   )
   def test_gradient_norms_on_various_models(

@@ -95,6 +96,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       layer_name,
       parameter_tuple,
       layer_registry_name,
+      num_microbatches,
       is_eager,
   ):
     # Parse inputs to generate test data.

@@ -121,7 +123,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       return common_test_utils.get_computed_and_true_norms_from_model(
           model=model,
           per_example_loss_fn=None,
-          num_microbatches=None,
+          num_microbatches=num_microbatches,
           x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch,
           weight_batch=None,
           registry=get_layer_norm_registries()[layer_registry_name],
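
These hunks extend an existing parameterized test; the decorator itself sits outside the hunks, but the keyword-list style matches absl's @parameterized.product, which generates one test case per element of the cross product of the parameter lists, so adding num_microbatches=[None, 2] doubles the number of generated cases. A minimal, self-contained sketch of that mechanism (assuming absl's parameterized.product; ToyProductTest is illustrative, not part of the repo):

from absl.testing import parameterized
import tensorflow as tf

class ToyProductTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.product(
      num_microbatches=[None, 2],
      is_eager=[True, False],
  )
  def test_expansion(self, num_microbatches, is_eager):
    # Four cases are generated: the cross product of the two lists above.
    self.assertIn(num_microbatches, (None, 2))
    self.assertIsInstance(is_eager, bool)

if __name__ == "__main__":
  tf.test.main()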