From a33afde0c105ece6c48b17a80f13899cf3e7c1b3 Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Tue, 22 Feb 2022 20:31:42 -0800 Subject: [PATCH] Add ability to specify number of microbatches in `DPModel` class. PiperOrigin-RevId: 430358084 --- .../privacy/keras_models/dp_keras_model.py | 56 +++++--- .../keras_models/dp_keras_model_test.py | 123 +++++++++++++++++- 2 files changed, 157 insertions(+), 22 deletions(-) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 9c81f78..261edb2 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -62,25 +62,37 @@ def make_dp_model_class(cls): self, l2_norm_clip, noise_multiplier, + num_microbatches=None, use_xla=True, *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args **kwargs): """Initializes the DPModelClass. - Args: - l2_norm_clip: Clipping norm (max L2 norm of per microbatch - gradients). - noise_multiplier: Ratio of the standard deviation to the clipping - norm. - use_xla: If `True`, compiles train_step to XLA. - *args: These will be passed on to the base class `__init__` method. - **kwargs: These will be passed on to the base class `__init__` - method. + Args: + l2_norm_clip: Clipping norm (max L2 norm of per microbatch + gradients). + noise_multiplier: Ratio of the standard deviation to the clipping + norm. + num_microbatches: Number of microbatches. + use_xla: If `True`, compiles train_step to XLA. + *args: These will be passed on to the base class `__init__` method. + **kwargs: These will be passed on to the base class `__init__` + method. """ super().__init__(*args, **kwargs) self._l2_norm_clip = l2_norm_clip self._noise_multiplier = noise_multiplier + # Given that `num_microbatches` was added as an argument after the fact, + # this check helps detect unintended calls to the earlier API. + # In particular, boolean values supplied to `use_xla` in the earlier API + # will raise an error. + if isinstance(num_microbatches, bool): + raise ValueError('Boolean value supplied for `num_microbatches`. ' + 'Did you intend it for `use_xla`?') + + self._num_microbatches = num_microbatches + if use_xla: self.train_step = tf.function( self.train_step, experimental_compile=True) @@ -106,21 +118,35 @@ def make_dp_model_class(cls): def _compute_per_example_grads(self, data): x, y = data with tf.GradientTape() as tape: - # We need to add the extra dimension to x and y because model - # expects batched input. - y_pred = self(x[None], training=True) - loss = self.compiled_loss( - y[None], y_pred, regularization_losses=self.losses) + y_pred = self(x, training=True) + loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) grads_list = tape.gradient(loss, self.trainable_variables) clipped_grads = self._process_per_example_grads(grads_list) - return tf.squeeze(y_pred, axis=0), loss, clipped_grads + return y_pred, loss, clipped_grads def train_step(self, data): """DP-SGD version of base class method.""" _, y = data + batch_size = y.shape[0] + + if self._num_microbatches is None: + self._num_microbatches = batch_size + if batch_size % self._num_microbatches != 0: + raise ValueError('Number of_microbatches must divide batch size.') + + def reshape_fn(x): + new_shape = (self._num_microbatches, + batch_size // self._num_microbatches) + x.shape[1:] + return tf.reshape(x, new_shape) + + data = tf.nest.map_structure(reshape_fn, data) + y_pred, _, per_eg_grads = tf.vectorized_map( self._compute_per_example_grads, data) + + y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:]) + grads = tf.nest.map_structure(self._reduce_per_example_grads, per_eg_grads) self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py index 8e51aa4..a8c8508 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py @@ -92,11 +92,74 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(model_weights[0], expected_weights) self.assertAllClose(model_weights[1], expected_bias) + def _compute_expected_gradients(self, data, labels, w, l2_norm_clip, + num_microbatches): + batch_size = data.shape[0] + if num_microbatches is None: + num_microbatches = batch_size + + preds = np.matmul(data, w) + + grads = 2 * data * (labels - preds)[:, np.newaxis] + grads = np.reshape(grads, + [num_microbatches, batch_size // num_microbatches, -1]) + + mb_grads = np.mean(grads, axis=1) + mb_grad_norms = np.linalg.norm(mb_grads, axis=1) + + scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0) + + mb_grads = mb_grads * scale[:, np.newaxis] + + final_grads = np.mean(mb_grads, axis=0) + return final_grads + @parameterized.named_parameters( - ('noise_multiplier 3 2', 3.0, 2.0), - ('noise_multiplier 5 4', 5.0, 4.0), + ('mb_test 0', 1.0, None), + ('mb_test 1', 1.0, 1), + ('mb_test 2', 1.0, 2), + ('mb_test 4', 1.0, 4), ) - def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier): + def testMicrobatches(self, l2_norm_clip, num_microbatches): + train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]) + w = np.zeros((2)) + train_labels = np.array([1.0, 3.0, -2.0, -4.0]) + learning_rate = 1.0 + + expected_grads = self._compute_expected_gradients(train_data, train_labels, + w, l2_norm_clip, + num_microbatches) + expected_weights = np.squeeze(learning_rate * expected_grads) + + optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + loss = tf.keras.losses.MeanSquaredError() + + # Simple linear model returns w * x + b. + model = dp_keras_model.DPSequential( + l2_norm_clip=l2_norm_clip, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + layers=[ + tf.keras.layers.InputLayer(input_shape=(2,)), + tf.keras.layers.Dense( + 1, use_bias=False, kernel_initializer='zeros') + ]) + model.compile(optimizer=optimizer, loss=loss) + model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False) + + model_weights = np.squeeze(model.get_weights()) + self.assertAllClose(model_weights, expected_weights) + + @parameterized.named_parameters( + ('noise_multiplier 3 2 1', 3.0, 2.0, 1), + ('noise_multiplier 5 4 1', 5.0, 4.0, 1), + ('noise_multiplier 3 2 2', 3.0, 2.0, 2), + ('noise_multiplier 5 4 2', 5.0, 4.0, 2), + ('noise_multiplier 3 2 4', 3.0, 2.0, 4), + ('noise_multiplier 5 4 4', 5.0, 4.0, 4), + ) + def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier, + num_microbatches): # The idea behind this test is to start with a model whose parameters # are set to zero. We then run one step of a model that produces # an un-noised gradient of zero, and then compute the standard deviation @@ -104,8 +167,8 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): # deviation. # Data is one example of length 1000, set to zero, with label zero. - train_data = np.zeros((1, 1000)) - train_labels = np.array([0.0]) + train_data = np.zeros((4, 1000)) + train_labels = np.array([0.0, 0.0, 0.0, 0.0]) learning_rate = 1.0 optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) @@ -115,21 +178,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase): model = dp_keras_model.DPSequential( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, + num_microbatches=num_microbatches, layers=[ tf.keras.layers.InputLayer(input_shape=(1000,)), tf.keras.layers.Dense( 1, kernel_initializer='zeros', bias_initializer='zeros') ]) model.compile(optimizer=optimizer, loss=loss) - model.fit(train_data, train_labels, epochs=1, batch_size=1) + model.fit(train_data, train_labels, epochs=1, batch_size=4) model_weights = model.get_weights() measured_std = np.std(model_weights[0]) - expected_std = l2_norm_clip * noise_multiplier + expected_std = l2_norm_clip * noise_multiplier / num_microbatches # Test standard deviation is close to l2_norm_clip * noise_multiplier. self.assertNear(measured_std, expected_std, 0.1 * expected_std) + # Simple check to make sure dimensions are correct when output has + # dimension > 1. + @parameterized.named_parameters( + ('mb_test None 1', None, 1), + ('mb_test 1 2', 1, 2), + ('mb_test 2 2', 2, 2), + ('mb_test 4 4', 4, 4), + ) + def testMultiDimensionalOutput(self, num_microbatches, output_dimension): + train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]) + train_labels = np.array([0, 1, 1, 0]) + learning_rate = 1.0 + + optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + + model = dp_keras_model.DPSequential( + l2_norm_clip=1.0e9, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + layers=[ + tf.keras.layers.InputLayer(input_shape=(2,)), + tf.keras.layers.Dense( + output_dimension, use_bias=False, kernel_initializer='zeros') + ]) + model.compile(optimizer=optimizer, loss=loss_fn) + model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False) + + # Checks that calls to earlier API using `use_xla` as a positional argument + # raise an exception. + @parameterized.named_parameters( + ('earlier API True', True), + ('earlier API False', False), + ) + def testEarlierAPIFails(self, use_xla): + with self.assertRaises(ValueError): + _ = dp_keras_model.DPSequential( + 1.0e9, + 0.0, + use_xla, + layers=[ + tf.keras.layers.InputLayer(input_shape=(2,)), + tf.keras.layers.Dense( + 2, use_bias=False, kernel_initializer='zeros') + ]) if __name__ == '__main__': tf.test.main()