Support models with unconnected layers and gradients when training using a DP vectorized optimizer.
PiperOrigin-RevId: 450659644
parent 5509adb296
commit 95e527acfb

2 changed files with 207 additions and 1 deletion
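At a glance, the new flag is consumed at optimizer construction time. A minimal usage sketch (hypothetical model and hyperparameter values, assuming the module's usual import path; only unconnected_gradients_to_zero is new in this change):

    import tensorflow as tf
    from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_vectorized

    # Stand-in model; the flag matters for models with deliberately
    # unconnected layers (see SimpleEmbeddingModel in the test diff below).
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

    optimizer = dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer(
        l2_norm_clip=1.0,
        noise_multiplier=0.5,
        num_microbatches=1,
        unconnected_gradients_to_zero=True,  # the flag added by this commit
        learning_rate=1.0)

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE),  # per-sample losses
        metrics=['accuracy'])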
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -469,5 +470,199 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotAllClose([3.0], var1)
 
 
+class SimpleEmbeddingModel(tf.keras.Model):
+  """Simple embedding model."""
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self.embed_layer = tf.keras.layers.Embedding(
+        name='embedding',
+        input_dim=10,  # vocabulary size.
+        output_dim=6,  # embedding size.
+        embeddings_initializer='uniform',
+        input_length=4)  # sequence length.
+    self.pool_layer = tf.keras.layers.Dense(
+        name='pooler',
+        units=6,
+        activation='tanh',
+        kernel_initializer='zeros',
+        bias_initializer='zeros')
+    # Note: defined here but never used in call().
+    self.probs_layer = tf.keras.layers.Dense(
+        units=1, activation='softmax', name='classification')
+
+  def call(self, inputs, training=None):
+    # The shape of the sequence output from the embedding layer is
+    # [batch_size, sequence_length, embedding_size].
+    sequence_output = self.embed_layer(inputs)
+    first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
+    # The shape of the pooled output from the pooler layer is
+    # [batch_size, embedding_size].
+    pooled_output = self.pool_layer(first_token_tensor)
+    return sequence_output, pooled_output
+
+
+def keras_embedding_model_fn(opt_cls,
+                             l2_norm_clip: float,
+                             noise_multiplier: float,
+                             num_microbatches: int,
+                             learning_rate: float,
+                             use_seq_output: bool = False,
+                             unconnected_gradients_to_zero: bool = False):
+  """Constructs a simple embedding model with a classification layer."""
+
+  # Every sample has 4 tokens (sequence length = 4).
+  x = tf.keras.layers.Input(shape=(4,), dtype=tf.float32, name='input')
+  sequence_output, pooled_output = SimpleEmbeddingModel()(x)
+  if use_seq_output:
+    # With the sequence output, the pooled branch inside SimpleEmbeddingModel
+    # is never used downstream, so the 'pooler' layer ends up unconnected.
+    embedding = sequence_output
+  else:
+    embedding = pooled_output
+  probs = tf.keras.layers.Dense(
+      units=1, activation='softmax', name='classification')(
+          embedding)
+  model = tf.keras.Model(inputs=x, outputs=probs, name='model')
+
+  optimizer = opt_cls(
+      l2_norm_clip=l2_norm_clip,
+      noise_multiplier=noise_multiplier,
+      num_microbatches=num_microbatches,
+      unconnected_gradients_to_zero=unconnected_gradients_to_zero,
+      learning_rate=learning_rate)
+
+  model.compile(
+      optimizer=optimizer,
+      loss=tf.keras.losses.MeanSquaredError(
+          # Return the per-sample loss.
+          reduction=tf.keras.losses.Reduction.NONE),
+      metrics=['accuracy'])
+  return model
+
+
+class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
+                                                parameterized.TestCase):
+  """Tests for vectorized optimizers when there are unconnected nodes.
+
+  Subclassed Keras models can have layers that are defined in the graph but
+  not connected to the input or output, or a conditional expression may
+  determine whether the layer in question is connected. In such cases, no
+  gradients are present for the unconnected layer. The vectorized DP
+  optimizers compute the per-microbatch gradients using the Jacobian of the
+  losses, and the Jacobian will contain 'None' values corresponding to that
+  layer, which causes an error in the gradient computation.
+
+  This error can be mitigated by setting those unconnected gradients to 0
+  instead of 'None', using the 'unconnected_gradients' flag of the
+  tf.GradientTape.jacobian() method.
+
+  This class tests the possible combinations of presence/absence of
+  unconnected layers and of setting unconnected gradients to 'None' or 0. In
+  these tests, 'unconnected_gradients_to_zero' is set to True if the
+  gradients are to be set to zero, or False if they are to be set to 'None'.
+  """
+
+  # Parameters for testing: optimizer.
+  @parameterized.named_parameters(
+      ('DPSGDVectorized_SeqOutput_UnconnectedGradients',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
+  def testSeqOutputUnconnectedGradientsAsNoneFails(self, cls):
+    """Tests that DP vectorized optimizers with 'None' unconnected gradients fail.
+
+    Sequence models that have unconnected gradients (with
+    'tf.UnconnectedGradients.NONE' passed to tf.GradientTape.jacobian) will
+    return a 'None' in the corresponding entry of the Jacobian. To mitigate
+    this, the 'unconnected_gradients_to_zero' flag is added to the
+    differentially private optimizers to support setting these gradients to
+    zero.
+
+    These tests cover the various combinations of this flag and the model.
+
+    Args:
+      cls: The DP optimizer class to test.
+    """
+
+    embedding_model = keras_embedding_model_fn(
+        cls,
+        l2_norm_clip=1.0,
+        noise_multiplier=0.5,
+        num_microbatches=1,
+        learning_rate=1.0,
+        use_seq_output=True,
+        unconnected_gradients_to_zero=False)
+
+    train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
+    train_labels = np.random.randint(0, 2, size=(1000, 1), dtype=np.int32)
+
+    def train_data_input_fn():
+      return tf.data.Dataset.from_tensor_slices(
+          (train_data, train_labels)).batch(8)
+
+    self.assertRaisesRegex(
+        ValueError,
+        'None values not supported',
+        embedding_model.fit,
+        x=train_data_input_fn(),
+        epochs=1,
+        verbose=0)
+
+  # Parameters for testing: optimizer.
+  @parameterized.named_parameters(
+      ('DPSGDVectorized_PooledOutput_UnconnectedGradients',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
+  def testPooledOutputUnconnectedGradientsAsNonePasses(self, cls):
+    """Tests that DP vectorized optimizers with 'None' unconnected gradients pass."""
+
+    embedding_model = keras_embedding_model_fn(
+        cls,
+        l2_norm_clip=1.0,
+        noise_multiplier=0.5,
+        num_microbatches=1,
+        learning_rate=1.0,
+        use_seq_output=False,
+        unconnected_gradients_to_zero=False)
+
+    train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
+    train_labels = np.random.randint(0, 2, size=(1000, 1), dtype=np.int32)
+
+    def train_data_input_fn():
+      return tf.data.Dataset.from_tensor_slices(
+          (train_data, train_labels)).batch(8)
+
+    try:
+      embedding_model.fit(x=train_data_input_fn(), epochs=1, verbose=0)
+    except ValueError:
+      # For a 'ValueError' exception the test should record a failure. All
+      # other exceptions are errors.
+      self.fail('ValueError raised by model.fit().')
+
+  # Parameters for testing: optimizer, use sequence output flag.
+  @parameterized.named_parameters(
+      ('DPSGDVectorized_SeqOutput_UnconnectedGradientsAreZero',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, True),
+      ('DPSGDVectorized_PooledOutput_UnconnectedGradientsAreZero',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False),
+  )
+  def testUnconnectedGradientsAsZeroPasses(self, cls, use_seq_output):
+    """Tests that DP vectorized optimizers with 'Zero' unconnected gradients pass."""
+
+    embedding_model = keras_embedding_model_fn(
+        cls,
+        l2_norm_clip=1.0,
+        noise_multiplier=0.5,
+        num_microbatches=1,
+        learning_rate=1.0,
+        use_seq_output=use_seq_output,
+        unconnected_gradients_to_zero=True)
+
+    train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
+    train_labels = np.random.randint(0, 2, size=(1000, 1), dtype=np.int32)
+
+    def train_data_input_fn():
+      return tf.data.Dataset.from_tensor_slices(
+          (train_data, train_labels)).batch(8)
+
+    try:
+      embedding_model.fit(x=train_data_input_fn(), epochs=1, verbose=0)
+    except ValueError:
+      # For a 'ValueError' exception the test should record a failure. All
+      # other exceptions are errors.
+      self.fail('ValueError raised by model.fit().')
+
+
 if __name__ == '__main__':
   tf.test.main()
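The mechanism the tests above rely on is the unconnected_gradients argument of tf.GradientTape.jacobian(), which the optimizer change below wires up to the new flag. A standalone toy sketch of the two behaviors (tensor values and variable names invented for illustration):

    import tensorflow as tf

    x = tf.constant([[1.0, 2.0]])
    w_used = tf.Variable([[1.0], [1.0]])    # participates in the output.
    w_unused = tf.Variable([[5.0], [5.0]])  # never touches the output.

    with tf.GradientTape(persistent=True) as tape:
      y = tf.matmul(x, w_used)

    # Default behavior: the entry for the unconnected variable is None,
    # which later breaks the per-microbatch gradient computation.
    print(tape.jacobian(y, [w_used, w_unused],
                        unconnected_gradients=tf.UnconnectedGradients.NONE))

    # With ZERO, the None becomes a zero tensor of the right shape; this is
    # what unconnected_gradients_to_zero=True requests.
    print(tape.jacobian(y, [w_used, w_unused],
                        unconnected_gradients=tf.UnconnectedGradients.ZERO))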
@@ -106,6 +106,7 @@ def make_vectorized_keras_optimizer_class(cls):
         l2_norm_clip,
         noise_multiplier,
         num_microbatches=None,
+        unconnected_gradients_to_zero=False,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
       """Initialize the DPOptimizerClass.
@@ -115,6 +116,10 @@ def make_vectorized_keras_optimizer_class(cls):
         noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches into which each minibatch is
           split.
+        unconnected_gradients_to_zero: The Jacobian of the microbatch losses
+          is used to compute per-microbatch gradients. If a node in the graph
+          is deliberately not connected, the Jacobian computation returns
+          `None` for that node; set this flag to True to treat it as zero.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
@@ -122,6 +127,7 @@ def make_vectorized_keras_optimizer_class(cls):
       self._l2_norm_clip = l2_norm_clip
       self._noise_multiplier = noise_multiplier
       self._num_microbatches = num_microbatches
+      self._unconnected_gradients_to_zero = unconnected_gradients_to_zero
       self._dp_sum_query = gaussian_query.GaussianSumQuery(
           l2_norm_clip, l2_norm_clip * noise_multiplier)
       self._global_state = None
@@ -164,7 +170,12 @@ def make_vectorized_keras_optimizer_class(cls):
 
       # Compute the per-microbatch losses using helpful jacobian method.
       with tf.keras.backend.name_scope(self._name + '/gradients'):
-        jacobian = tape.jacobian(microbatch_losses, var_list)
+        jacobian = tape.jacobian(
+            microbatch_losses,
+            var_list,
+            unconnected_gradients=(tf.UnconnectedGradients.ZERO
+                                   if self._unconnected_gradients_to_zero else
+                                   tf.UnconnectedGradients.NONE))
 
         clipped_gradients = tf.vectorized_map(
             lambda g: clip_gradients_vmap(g, self._l2_norm_clip), jacobian)
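To see how the changed jacobian call feeds the rest of the vectorized pipeline, here is a self-contained sketch of per-microbatch clipping driven by the Jacobian. The clipping helper only paraphrases the role of clip_gradients_vmap (joint clipping across variables); all names and values are illustrative, not the library's exact code:

    import tensorflow as tf

    l2_norm_clip = 1.0

    def clip_gradients(per_microbatch_grads, clip_norm):
      # Clip one microbatch's gradients jointly across all variables,
      # mirroring the role of clip_gradients_vmap in the diff above.
      flat = tf.nest.flatten(per_microbatch_grads)
      clipped, _ = tf.clip_by_global_norm(flat, clip_norm)
      return tf.nest.pack_sequence_as(per_microbatch_grads, clipped)

    # Toy setup: w_unused never affects the loss, so its Jacobian entry
    # would be None without UnconnectedGradients.ZERO.
    w = tf.Variable([[0.5], [0.5]])
    w_unused = tf.Variable([[1.0]])
    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # two microbatches of size 1.
    y = tf.constant([[1.0], [2.0]])

    with tf.GradientTape() as tape:
      microbatch_losses = tf.reduce_mean(
          tf.square(tf.matmul(x, w) - y), axis=1)  # shape [num_microbatches]

    jacobian = tape.jacobian(
        microbatch_losses, [w, w_unused],
        unconnected_gradients=tf.UnconnectedGradients.ZERO)

    # Clip each microbatch's gradients independently, then average; a real
    # DP optimizer also adds Gaussian noise to the sum before averaging.
    clipped = tf.vectorized_map(
        lambda g: clip_gradients(g, l2_norm_clip), jacobian)
    grads = [tf.reduce_mean(g, axis=0) for g in clipped]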