forked from 626_privacy/tensorflow_privacy
Add utility functions for unwrapping BERT encoder layers into individual Keras layers.
PiperOrigin-RevId: 588419989
This commit is contained in:
parent 93376c9d6a
commit fbe5879023
3 changed files with 254 additions and 0 deletions
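For context, a minimal usage sketch of the new utility (not part of the commit itself; the encoder configuration below mirrors the small test encoder built in bert_encoder_utils_test.py):

import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import bert_encoder_utils

# Build a small BERT encoder (hypothetical configuration, taken from the test).
encoder = tfm.nlp.networks.BertEncoder(
    vocab_size=3,
    hidden_size=2,
    num_attention_heads=2,
    num_layers=3,
    max_sequence_length=5,
    inner_dim=6,
    type_vocab_size=4,
)

# Replace the composite TransformerEncoderBlock layers with a functionally
# equivalent tf.keras.Model built from core Keras layers.
unwrapped_encoder = bert_encoder_utils.get_unwrapped_bert_encoder(encoder)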
@@ -6,6 +6,20 @@ py_library(
    srcs_version = "PY3",
)

py_library(
    name = "bert_encoder_utils",
    srcs = ["bert_encoder_utils.py"],
    srcs_version = "PY3",
    deps = [":gradient_clipping_utils"],
)

py_test(
    name = "bert_encoder_utils_test",
    srcs = ["bert_encoder_utils_test.py"],
    srcs_version = "PY3",
    deps = [":bert_encoder_utils"],
)

py_library(
    name = "common_manip_utils",
    srcs = ["common_manip_utils.py"],
@@ -0,0 +1,82 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for manipulating official Tensorflow BERT encoders."""

import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils


def dedup_bert_encoder(input_bert_encoder: tfm.nlp.networks.BertEncoder):
  """Deduplicates the layer names in a BERT encoder."""

  def _dedup(layer, attr_name, new_name):
    sublayer = getattr(layer, attr_name)
    if sublayer is None:
      return
    else:
      sublayer_config = sublayer.get_config()
      sublayer_config["name"] = new_name
      setattr(layer, attr_name, sublayer.from_config(sublayer_config))

  for layer in input_bert_encoder.layers:
    # NOTE: the ordering of the renames is important for the ordering of the
    # variables in the computed gradients. This is why we use three `for`
    # loops instead of one.
    if isinstance(layer, tfm.nlp.layers.TransformerEncoderBlock):
      # pylint: disable=protected-access
      for attr_name in ["inner_dropout_layer", "attention_dropout"]:
        _dedup(layer, "_" + attr_name, layer.name + "/" + attr_name)
      # Some layers are nested within the main attention layer (if it exists).
      if layer._attention_layer is not None:
        prefix = layer.name + "/" + layer._attention_layer.name
        _dedup(layer, "_attention_layer", prefix + "/attention_layer")
        _dedup(
            layer._attention_layer,
            "_dropout_layer",
            prefix + "/attention_inner_dropout_layer",
        )
      for attr_name in ["attention_layer_norm", "intermediate_dense"]:
        _dedup(layer, "_" + attr_name, layer.name + "/" + attr_name)
      # This is one of the few times that we cannot build from a config, due
      # to the presence of lambda functions.
      if layer._intermediate_activation_layer is not None:
        policy = tf.keras.mixed_precision.global_policy()
        if policy.name == "mixed_bfloat16":
          policy = tf.float32
        layer._intermediate_activation_layer = tf.keras.layers.Activation(
            layer._inner_activation,
            dtype=policy,
            name=layer.name + "/intermediate_activation_layer",
        )
      for attr_name in ["output_dense", "output_dropout", "output_layer_norm"]:
        _dedup(layer, "_" + attr_name, layer.name + "/" + attr_name)
      # pylint: enable=protected-access


def get_unwrapped_bert_encoder(
    input_bert_encoder: tfm.nlp.networks.BertEncoder,
) -> tfm.nlp.networks.BertEncoder:
  """Creates a new BERT encoder whose layers are core Keras layers."""
  dedup_bert_encoder(input_bert_encoder)
  core_test_outputs = (
      gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
          input_bert_encoder,
          custom_layer_set={tfm.nlp.layers.TransformerEncoderBlock},
      )
  )
  return tf.keras.Model(
      inputs=input_bert_encoder.inputs,
      outputs=core_test_outputs,
  )
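As a side note, the `_dedup` helper above relies on the standard Keras `get_config` / `from_config` round trip to rebuild a sublayer under a new, unique name. A minimal standalone sketch of that pattern (illustrative names only, not part of the commit):

import tensorflow as tf

# Serialize a layer's config, change its name, and rebuild the layer so that
# the rebuilt copy no longer collides with an identically named sibling.
dropout = tf.keras.layers.Dropout(0.1, name="dropout")
config = dropout.get_config()
config["name"] = "transformer/layer_0/inner_dropout_layer"  # hypothetical unique name
renamed = tf.keras.layers.Dropout.from_config(config)
assert renamed.name == "transformer/layer_0/inner_dropout_layer"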
@@ -0,0 +1,158 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of `bert_encoder_utils.py`."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import bert_encoder_utils


def compute_bert_sample_inputs(
    batch_size, sequence_length, vocab_size, num_types
):
  """Returns a set of BERT encoder inputs."""
  word_id_sample = np.random.randint(
      vocab_size, size=(batch_size, sequence_length)
  )
  mask_sample = np.random.randint(2, size=(batch_size, sequence_length))
  type_id_sample = np.random.randint(
      num_types,
      size=(batch_size, sequence_length),
  )
  return [word_id_sample, mask_sample, type_id_sample]


def get_small_bert_encoder_and_sample_inputs(dict_outputs=False):
  """Returns a small BERT encoder and sample inputs for testing."""
  hidden_size = 2
  vocab_size = 3
  num_types = 4
  max_sequence_length = 5
  inner_dense_units = 6
  output_range = 1
  num_heads = 2
  num_transformer_layers = 3
  seed = 777

  bert_encoder = tfm.nlp.networks.BertEncoder(
      vocab_size=vocab_size,
      hidden_size=hidden_size,
      num_attention_heads=num_heads,
      num_layers=num_transformer_layers,
      max_sequence_length=max_sequence_length,
      inner_dim=inner_dense_units,
      type_vocab_size=num_types,
      output_range=output_range,
      initializer=tf.keras.initializers.GlorotUniform(seed),
      dict_outputs=dict_outputs,
  )

  batch_size = 3
  bert_sample_inputs = compute_bert_sample_inputs(
      batch_size,
      max_sequence_length,
      vocab_size,
      num_types,
  )

  return bert_encoder, bert_sample_inputs


def get_shared_trainable_variables(model1, model2):
  """Returns the shared trainable variables (by name) between models."""
  common_names = {v.name for v in model1.trainable_variables} & {
      v.name for v in model2.trainable_variables
  }
  tvars1 = [v for v in model1.trainable_variables if v.name in common_names]
  tvars2 = [v for v in model2.trainable_variables if v.name in common_names]
  return tvars1, tvars2


def custom_reduced_loss(y_batch, y_pred):
  del y_batch
  # Create a loss multiplier to avoid small gradients.
  large_value_multiplier = 1e10
  sqr_outputs = []
  for t in y_pred:
    reduction_axes = tf.range(1, len(t.shape))
    sqr_outputs.append(tf.reduce_sum(tf.square(t), axis=reduction_axes))
  sqr_tsr = tf.stack(sqr_outputs, axis=1)
  return large_value_multiplier * tf.reduce_sum(sqr_tsr, axis=1)


class BertEncoderUtilsTest(tf.test.TestCase, parameterized.TestCase):

  def test_outputs_are_equal(self):
    true_encoder, sample_inputs = get_small_bert_encoder_and_sample_inputs()
    unwrapped_encoder = bert_encoder_utils.get_unwrapped_bert_encoder(
        true_encoder
    )
    true_outputs = true_encoder(sample_inputs)
    computed_outputs = unwrapped_encoder(sample_inputs)
    self.assertAllClose(true_outputs, computed_outputs)

  def test_shared_trainable_variables_are_equal(self):
    true_encoder, sample_inputs = get_small_bert_encoder_and_sample_inputs()
    unwrapped_encoder = bert_encoder_utils.get_unwrapped_bert_encoder(
        true_encoder
    )
    # Initializes the trainable variable shapes.
    true_encoder(sample_inputs)
    unwrapped_encoder(sample_inputs)
    # The official BERT encoder may initialize trainable variables that are
    # not used in a model forward pass. Hence, they are invisible when we
    # try to unwrap layers using our utility function.
    true_vars, computed_vars = get_shared_trainable_variables(
        true_encoder, unwrapped_encoder
    )
    self.assertAllClose(true_vars, computed_vars)

  def test_shared_gradients_are_equal(self):
    true_encoder, sample_inputs = get_small_bert_encoder_and_sample_inputs()
    unwrapped_encoder = bert_encoder_utils.get_unwrapped_bert_encoder(
        true_encoder
    )
    # `custom_reduced_loss` applies a large multiplier to avoid small gradients.
    dummy_labels = None
    with tf.GradientTape(persistent=True) as tape:
      true_outputs = true_encoder(sample_inputs)
      true_sqr_sum = tf.reduce_sum(
          custom_reduced_loss(dummy_labels, true_outputs)
      )
      computed_outputs = unwrapped_encoder(sample_inputs)
      computed_sqr_sum = tf.reduce_sum(
          custom_reduced_loss(dummy_labels, computed_outputs)
      )
    # The official BERT encoder may initialize trainable variables that are
    # not used in a model forward pass. Hence, they are invisible when we
    # try to unwrap layers using our utility function.
    true_vars, computed_vars = get_shared_trainable_variables(
        true_encoder, unwrapped_encoder
    )
    true_grads = tape.gradient(true_sqr_sum, true_vars)
    computed_grads = tape.gradient(computed_sqr_sum, computed_vars)
    self.assertEqual(len(true_grads), len(computed_grads))
    for g1, g2 in zip(true_grads, computed_grads):
      self.assertEqual(type(g1), type(g2))
      if isinstance(g1, tf.IndexedSlices):
        self.assertAllClose(g1.values, g2.values)
        self.assertAllEqual(g1.indices, g2.indices)
      else:
        self.assertAllClose(g1, g2)


if __name__ == '__main__':
  tf.test.main()