From 95e527acfb8d3f26a74d3ad17931b7640b3a947f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 24 May 2022 05:36:22 -0700
Subject: [PATCH] Support models with unconnected layers and gradients when
 training using a DP vectorized optimizer.

PiperOrigin-RevId: 450659644
---
 .../optimizers/dp_optimizer_keras_test.py    | 195 ++++++++++++++++++
 .../dp_optimizer_keras_vectorized.py         |  13 +-
 2 files changed, 207 insertions(+), 1 deletion(-)

diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py
index 8417750..045b118 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -469,5 +470,199 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotAllClose([3.0], var1)
 
 
+class SimpleEmbeddingModel(tf.keras.Model):
+  """Simple embedding model."""
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self.embed_layer = tf.keras.layers.Embedding(
+        name='embedding',
+        input_dim=10,  # vocabulary size.
+        output_dim=6,  # embedding size.
+        embeddings_initializer='uniform',
+        input_length=4)  # sequence length.
+    self.pool_layer = tf.keras.layers.Dense(
+        name='pooler',
+        units=6,
+        activation='tanh',
+        kernel_initializer='zeros',
+        bias_initializer='zeros')
+    # Note: this layer is never invoked in `call` below, so it is never built
+    # and contributes no variables to the model.
+    self.probs_layer = tf.keras.layers.Dense(
+        units=1, activation='softmax', name='classification')
+
+  def call(self, inputs, training=None):
+    # The shape of the sequence output from the embedding layer is
+    # [batch_size, sequence_length, embedding_size].
+    sequence_output = self.embed_layer(inputs)
+    first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
+    # The shape of the pooled output from the pooler layer is
+    # [batch_size, embedding_size].
+    pooled_output = self.pool_layer(first_token_tensor)
+    return sequence_output, pooled_output
+
+
+def keras_embedding_model_fn(opt_cls,
+                             l2_norm_clip: float,
+                             noise_multiplier: float,
+                             num_microbatches: int,
+                             learning_rate: float,
+                             use_seq_output: bool = False,
+                             unconnected_gradients_to_zero: bool = False):
+  """Construct a simple embedding model with a classification layer."""
+
+  # Every sample has 4 tokens (sequence length=4).
+  x = tf.keras.layers.Input(shape=(4,), dtype=tf.float32, name='input')
+  sequence_output, pooled_output = SimpleEmbeddingModel()(x)
+  if use_seq_output:
+    embedding = sequence_output
+  else:
+    embedding = pooled_output
+  probs = tf.keras.layers.Dense(
+      units=1, activation='softmax', name='classification')(
+          embedding)
+  model = tf.keras.Model(inputs=x, outputs=probs, name='model')
+
+  optimizer = opt_cls(
+      l2_norm_clip=l2_norm_clip,
+      noise_multiplier=noise_multiplier,
+      num_microbatches=num_microbatches,
+      unconnected_gradients_to_zero=unconnected_gradients_to_zero,
+      learning_rate=learning_rate)
+
+  model.compile(
+      optimizer=optimizer,
+      loss=tf.keras.losses.MeanSquaredError(
+          # Return per-sample loss.
+          reduction=tf.keras.losses.Reduction.NONE),
+      metrics=['accuracy'])
+  return model
+
+
+class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase,
+                                                parameterized.TestCase):
+  """Tests for vectorized optimizers when there are unconnected nodes.
+
+  Subclassed Keras models can define layers in the graph that are not
+  connected to the input or output, or a conditional expression may decide
+  whether a given layer is connected. In such cases, no gradients exist for
+  the unconnected layer. The vectorized DP optimizers compute the
+  per-microbatch gradients using the Jacobian, which will contain 'None'
+  values corresponding to that layer. This causes an error in the gradient
+  computation.
+  This error can be mitigated by setting those unconnected gradients to 0
+  instead of 'None', using the 'unconnected_gradients' flag of the
+  tf.GradientTape.jacobian() method.
+  These tests cover the possible combinations of presence/absence of
+  unconnected layers and of setting unconnected gradients to 'None' or 0,
+  by setting 'unconnected_gradients_to_zero' to True if the gradients are to
+  be set to zero, or False if they are to be left as 'None'.
+  """
+
+  # Parameters for testing: optimizer.
+  @parameterized.named_parameters(
+      ('DPSGDVectorized_SeqOutput_UnconnectedGradients',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
+  def testSeqOutputUnconnectedGradientsAsNoneFails(self, cls):
+    """Tests that DP vectorized optimizers with 'None' unconnected gradients fail.
+
+    Sequence models that have unconnected gradients (with
+    'tf.UnconnectedGradients.NONE' passed to tf.GradientTape.jacobian) will
+    have a 'None' in the corresponding entry in the Jacobian. To mitigate
+    this, the 'unconnected_gradients_to_zero' flag is added to the
+    differentially private optimizers to support setting these gradients to
+    zero.
+
+    The tests cover the various combinations of this flag and the model.
+
+    Args:
+      cls: The DP optimizer class to test.
+    """
+
+    embedding_model = keras_embedding_model_fn(
+        cls,
+        l2_norm_clip=1.0,
+        noise_multiplier=0.5,
+        num_microbatches=1,
+        learning_rate=1.0,
+        use_seq_output=True,
+        unconnected_gradients_to_zero=False)
+
+    train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
+    train_labels = np.random.randint(0, 2, size=(1000, 1), dtype=np.int32)
+
+    def train_data_input_fn():
+      return tf.data.Dataset.from_tensor_slices(
+          (train_data, train_labels)).batch(8)
+
+    self.assertRaisesRegex(
+        ValueError,
+        'None values not supported',
+        embedding_model.fit,
+        x=train_data_input_fn(),
+        epochs=1,
+        verbose=0)
+
+  # Parameters for testing: optimizer.
+  @parameterized.named_parameters(
+      ('DPSGDVectorized_PooledOutput_UnconnectedGradients',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),)
+  def testPooledOutputUnconnectedGradientsAsNonePasses(self, cls):
+    """Tests that DP vectorized optimizers pass with no unconnected gradients."""
+
+    embedding_model = keras_embedding_model_fn(
+        cls,
+        l2_norm_clip=1.0,
+        noise_multiplier=0.5,
+        num_microbatches=1,
+        learning_rate=1.0,
+        use_seq_output=False,
+        unconnected_gradients_to_zero=False)
+
+    train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
+    train_labels = np.random.randint(0, 2, size=(1000, 1), dtype=np.int32)
+
+    def train_data_input_fn():
+      return tf.data.Dataset.from_tensor_slices(
+          (train_data, train_labels)).batch(8)
+
+    try:
+      embedding_model.fit(x=train_data_input_fn(), epochs=1, verbose=0)
+    except ValueError:
+      # A 'ValueError' should be recorded as a test failure; any other
+      # exception is a test error.
+      self.fail('ValueError raised by model.fit().')
+
+  # Parameters for testing: optimizer, use sequence output flag.
+  @parameterized.named_parameters(
+      ('DPSGDVectorized_SeqOutput_UnconnectedGradientsAreZero',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, True),
+      ('DPSGDVectorized_PooledOutput_UnconnectedGradientsAreZero',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False),
+  )
+  def testUnconnectedGradientsAsZeroPasses(self, cls, use_seq_output):
+    """Tests that DP vectorized optimizers with zeroed unconnected gradients pass."""
+
+    embedding_model = keras_embedding_model_fn(
+        cls,
+        l2_norm_clip=1.0,
+        noise_multiplier=0.5,
+        num_microbatches=1,
+        learning_rate=1.0,
+        use_seq_output=use_seq_output,
+        unconnected_gradients_to_zero=True)
+
+    train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32)
+    train_labels = np.random.randint(0, 2, size=(1000, 1), dtype=np.int32)
+
+    def train_data_input_fn():
+      return tf.data.Dataset.from_tensor_slices(
+          (train_data, train_labels)).batch(8)
+
+    try:
+      embedding_model.fit(x=train_data_input_fn(), epochs=1, verbose=0)
+    except ValueError:
+      # A 'ValueError' should be recorded as a test failure; any other
+      # exception is a test error.
+      self.fail('ValueError raised by model.fit().')
+
+
 if __name__ == '__main__':
   tf.test.main()

diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py
index 3481e6d..16a12d5 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py
@@ -106,6 +106,7 @@ def make_vectorized_keras_optimizer_class(cls):
                  l2_norm_clip,
                  noise_multiplier,
                  num_microbatches=None,
+                 unconnected_gradients_to_zero=False,
                  *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
                  **kwargs):
       """Initialize the DPOptimizerClass.
@@ -115,6 +116,10 @@
         noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches into which each minibatch is
           split.
+        unconnected_gradients_to_zero: The Jacobian of the microbatch losses
+          is used to compute the per-microbatch gradients. If a node in the
+          graph is deliberately not connected, the Jacobian contains `None`
+          for that node. Set this flag to True to treat those entries as zero.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
@@ -122,6 +127,7 @@
       self._l2_norm_clip = l2_norm_clip
       self._noise_multiplier = noise_multiplier
       self._num_microbatches = num_microbatches
+      self._unconnected_gradients_to_zero = unconnected_gradients_to_zero
      self._dp_sum_query = gaussian_query.GaussianSumQuery(
           l2_norm_clip, l2_norm_clip * noise_multiplier)
       self._global_state = None
@@ -164,7 +170,12 @@
       # Compute the per-microbatch losses using helpful jacobian method.
       with tf.keras.backend.name_scope(self._name + '/gradients'):
-        jacobian = tape.jacobian(microbatch_losses, var_list)
+        jacobian = tape.jacobian(
+            microbatch_losses,
+            var_list,
+            unconnected_gradients=(tf.UnconnectedGradients.ZERO
+                                   if self._unconnected_gradients_to_zero else
+                                   tf.UnconnectedGradients.NONE))
         clipped_gradients = tf.vectorized_map(
             lambda g: clip_gradients_vmap(g, self._l2_norm_clip), jacobian)
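
For context, a minimal standalone sketch (not part of the patch; the variable
names are illustrative only) of the tf.GradientTape.jacobian() behavior the
change above relies on:

import tensorflow as tf

connected = tf.Variable([1.0, 2.0])
unconnected = tf.Variable([3.0, 4.0])  # Not used in the computation below.

with tf.GradientTape(persistent=True) as tape:
  # Two per-microbatch losses, both depending only on `connected`.
  losses = tf.stack(
      [tf.reduce_sum(connected), 2.0 * tf.reduce_sum(connected)])

# Default behavior: the Jacobian entry for `unconnected` is None, which
# later breaks tf.vectorized_map over the per-variable Jacobians.
jac = tape.jacobian(losses, [connected, unconnected],
                    unconnected_gradients=tf.UnconnectedGradients.NONE)
print(jac[1])  # None

# With ZERO, the entry is a zero tensor of shape
# [num_microbatches, *var.shape], so clipping applies uniformly.
jac = tape.jacobian(losses, [connected, unconnected],
                    unconnected_gradients=tf.UnconnectedGradients.ZERO)
print(jac[1])  # Zero tensor of shape (2, 2).

In the optimizer above, the same substitution happens per trainable variable:
unconnected variables contribute zeros to the per-microbatch clipping and
receive only the added DP noise, instead of crashing the gradient computation.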