diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
index 93840d6..439febc 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -93,8 +93,6 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
     loss_config['reduction'] = tf.keras.losses.Reduction.NONE
     per_example_loss_fn = input_model.loss.from_config(loss_config)
     losses = per_example_loss_fn(y_batch, model_outputs)
-    if tf.rank(tf.squeeze(losses)) > 1:
-      raise NotImplementedError('Vector losses are not supported.')
     summed_loss = tf.reduce_sum(losses)
   # Second loop computes the norm of the gradient of the loss with respect to
   # the pre-activation tensors, and multiplies these norms with the results of
diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD
index d19e0e5..407537a 100644
--- a/tensorflow_privacy/privacy/keras_models/BUILD
+++ b/tensorflow_privacy/privacy/keras_models/BUILD
@@ -15,6 +15,11 @@ py_library(
         "dp_keras_model.py",
     ],
     srcs_version = "PY3",
+    deps = [
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
+    ],
 )
 
 py_test(
@@ -22,5 +27,8 @@ py_test(
     srcs = ["dp_keras_model_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/keras_models:dp_keras_model"],
+    deps = [
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
+        "//tensorflow_privacy/privacy/keras_models:dp_keras_model",
+    ],
 )
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
index 261edb2..c2e2421 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -13,19 +13,38 @@
 # limitations under the License.
 """Keras Model for vectorized dpsgd with XLA acceleration."""
 
+from absl import logging
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 
 
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
 
-  class DPModelClass(cls):  # pylint: disable=empty-docstring
-    __doc__ = ("""DP subclass of `{base_model}`.
+  class DPModelClass(cls):  # pylint: disable=missing-class-docstring
+    __doc__ = (
+        """DP subclass of `{base_model}`.
 
     This can be used as a differentially private replacement for
     {base_model}. This class implements DP-SGD using the standard
     Gaussian mechanism.
 
+    This class also utilizes a faster gradient clipping algorithm if the
+    following two conditions hold:
+    (i)  the trainable layers of the model are keys in the `dict` input
+      `layer_registry`,
+    (ii) the loss `tf.Tensor` for a given batch of examples is either a
+      scalar or a 2D `tf.Tensor` that has only one column
+      `(i.e., tf.shape(loss)[1] == 1)` and whose i-th row corresponds to
+      the loss of the i-th example.
+    This clipping algorithm specifically computes clipped gradients at the
+    per-example level using the layer registry functions in `layer_registry`
+    (see clip_grads.py for more information about the algorithm). In this
+    setting, microbatching is not used (it is equivalent to
+    `num_microbatches == batch_size`), and the input `num_microbatches`
+    is ignored.
+
     When instantiating this class, you need to supply several DP-related
     arguments followed by the standard arguments for `{short_base_model}`.
 
@@ -53,10 +72,12 @@ def make_dp_model_class(cls):
       model.fit(train_data, train_labels, epochs=1, batch_size=32)
       ```
 
-  """).format(
-      base_model='tf.keras.' + cls.__name__,
-      short_base_model=cls.__name__,
-      dp_model_class='DP' + cls.__name__)
+  """
+    ).format(
+        base_model='tf.keras.' + cls.__name__,
+        short_base_model=cls.__name__,
+        dp_model_class='DP' + cls.__name__,
+    )
 
     def __init__(
         self,
@@ -64,24 +85,31 @@ def make_dp_model_class(cls):
         noise_multiplier,
         num_microbatches=None,
         use_xla=True,
+        layer_registry=None,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
-        **kwargs):
+        **kwargs,
+    ):
      """Initializes the DPModelClass.
 
      Args:
-        l2_norm_clip: Clipping norm (max L2 norm of per microbatch
-          gradients).
-        noise_multiplier: Ratio of the standard deviation to the clipping
-          norm.
+        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+        noise_multiplier: Ratio of the standard deviation to the clipping norm.
        num_microbatches: Number of microbatches.
        use_xla: If `True`, compiles train_step to XLA.
+        layer_registry: A `dict` of layers that support "fast" gradient norm
+          computations. The key is the class of the layer and the value is a
+          function that returns a `tuple` `(output, sqr_grad_norms, vars)`,
+          where `output` is the pre-activation tensor, `sqr_grad_norms` is
+          related to the squared norms of a layer's pre-activation tensor, and
+          `vars` are relevant trainable weights (see
+          `layer_registry_factories.py` for examples).
        *args: These will be passed on to the base class `__init__` method.
-        **kwargs: These will be passed on to the base class `__init__`
-          method.
+        **kwargs: These will be passed on to the base class `__init__` method.
      """
      super().__init__(*args, **kwargs)
      self._l2_norm_clip = l2_norm_clip
      self._noise_multiplier = noise_multiplier
+      self._layer_registry = layer_registry
 
      # Given that `num_microbatches` was added as an argument after the fact,
      # this check helps detect unintended calls to the earlier API.
@@ -91,7 +119,27 @@ def make_dp_model_class(cls):
        raise ValueError('Boolean value supplied for `num_microbatches`. '
                         'Did you intend it for `use_xla`?')
 
-      self._num_microbatches = num_microbatches
+      # If all the trainable layers are in the input layer registry, we
+      # don't need to use microbatching and can instead use the "fast"
+      # chain rule trick for computing per-example gradients (peg).
+      if (
+          layer_registry is not None
+          and gradient_clipping_utils.all_trainable_layers_are_registered(
+              self, layer_registry
+          )
+          and gradient_clipping_utils.has_internal_compute_graph(self)
+      ):
+        if num_microbatches is not None:
+          raise ValueError(
+              'Cannot initialize a model where num_microbatches '
+              'is not `None` and all trainable layers are '
+              'registered in layer_registry.'
+          )
+        self._num_microbatches = None
+        self._enable_fast_peg_computation = True
+      else:
+        self._num_microbatches = num_microbatches
+        self._enable_fast_peg_computation = False
 
      if use_xla:
        self.train_step = tf.function(
@@ -126,29 +174,72 @@ def make_dp_model_class(cls):
        return y_pred, loss, clipped_grads
 
    def train_step(self, data):
-      """DP-SGD version of base class method."""
-      _, y = data
-      batch_size = y.shape[0]
+      """DP-SGD version of base class method.
 
-      if self._num_microbatches is None:
-        self._num_microbatches = batch_size
-      if batch_size % self._num_microbatches != 0:
-        raise ValueError('Number of_microbatches must divide batch size.')
+      Uses the "fast" gradient clipping algorithm to generate per-example
+      clipped gradients if (i) all the trainable layers of the model are
+      registered in the layer_registry input of the model constructor and
+      (ii) if the model contains an internal compute graph (e.g., this
+      condition is satisfied if the model subclasses the keras.Sequential or
+      keras.engine.functional.Functional class).
 
-      def reshape_fn(x):
-        new_shape = (self._num_microbatches,
-                     batch_size // self._num_microbatches) + x.shape[1:]
-        return tf.reshape(x, new_shape)
+      If (i) and (ii) above do not hold, then clips and aggregates
+      gradients at the microbatch level.
 
-      data = tf.nest.map_structure(reshape_fn, data)
+      Args:
+        data: see the base class.
 
-      y_pred, _, per_eg_grads = tf.vectorized_map(
-          self._compute_per_example_grads, data)
+      Returns:
+        See the base class.
+      """
+      if self._enable_fast_peg_computation:
+        logging.info(
+            'Computing gradients using the fast per-example gradient '
+            'norm algorithm.'
+        )
+        # Computes the per-example gradient norms using a "fast" clipping
+        # trick, and uses these norms to clip the per-example gradients.
+        x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
+        y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
+            self, x, y, self._l2_norm_clip, self._layer_registry
+        )
+        grads = gradient_clipping_utils.add_aggregate_noise(
+            self, x, clipped_grads, self._l2_norm_clip, self._noise_multiplier
+        )
+      else:
+        logging.info('Computing gradients using microbatching.')
+        # Computes per-example clipped gradients directly. This is called
+        # if at least one of the layers cannot use the "fast" gradient clipping
+        # algorithm.
+        # TODO(wkong): check if the following is valid with sample weights.
+        _, y = data
+        batch_size = y.shape[0]
 
-      y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
+        if self._num_microbatches is None:
+          self._num_microbatches = batch_size
+        if batch_size % self._num_microbatches != 0:
+          raise ValueError('Number of microbatches must divide batch size.')
 
-      grads = tf.nest.map_structure(self._reduce_per_example_grads,
-                                    per_eg_grads)
+        def reshape_fn(x):
+          new_shape = (
+              self._num_microbatches,
+              batch_size // self._num_microbatches,
+          ) + x.shape[1:]
+          return tf.reshape(x, new_shape)
+
+        data = tf.nest.map_structure(reshape_fn, data)
+
+        y_pred, _, per_eg_grads = tf.vectorized_map(
+            self._compute_per_example_grads, data
+        )
+
+        y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
+
+        grads = tf.nest.map_structure(
+            self._reduce_per_example_grads, per_eg_grads
+        )
+
+      # Forward the private gradients to the optimizer and return the results.
       self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
       self.compiled_metrics.update_state(y, y_pred)
       return {m.name: m.result() for m in self.metrics}
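An illustrative usage sketch (not part of the diff itself): how a model built from this change would opt into the fast clipping path. It assumes the default registry from `layer_registry_factories` covers all trainable layers (here a single `Dense` layer, as in the tests below); `num_microbatches` is left unset because the fast path does not use microbatching. Layer sizes and hyperparameters are hypothetical.

```python
import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
from tensorflow_privacy.privacy.keras_models import dp_keras_model

# All trainable layers below are Dense, which the default registry handles,
# so the constructor enables the fast per-example gradient computation.
model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    layer_registry=layer_registry_factories.make_default_layer_registry(),
    layers=[
        tf.keras.layers.InputLayer(input_shape=(2,)),
        tf.keras.layers.Dense(1),
    ],
)
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    # MSE on a single-output model yields one loss value per example, which
    # satisfies condition (ii) of the class docstring.
    loss=tf.keras.losses.MeanSquaredError(),
)

x = np.random.normal(size=(8, 2)).astype(np.float32)
y = np.random.normal(size=(8, 1)).astype(np.float32)
model.fit(x, y, epochs=1, batch_size=4)
```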
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py
index a8c8508..172a853 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
-
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
 from tensorflow_privacy.privacy.keras_models import dp_keras_model
 
@@ -29,6 +28,13 @@ def get_data():
   return data, labels
 
 
+def get_layer_registries():
+  # Outputs a list of testable layer registries.
+  # The empty registry {} tests the behavior of the standard approach,
+  # while the other one tests the fast gradient clipping algorithm.
+  return [{}, layer_registry_factories.make_default_layer_registry()]
+
+
 class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
 
   def testBaseline(self):
@@ -65,32 +71,35 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     """Tests that clipping norm works."""
     train_data, train_labels = get_data()
 
-    # Simple linear model returns w * x + b.
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=l2_norm_clip,
-        noise_multiplier=0.0,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(2,)),
-            tf.keras.layers.Dense(
-                1, kernel_initializer='zeros', bias_initializer='zeros')
-        ])
-    learning_rate = 0.01
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss = tf.keras.losses.MeanSquaredError()
+    for test_reg in get_layer_registries():
+      # Simple linear model returns w * x + b.
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=l2_norm_clip,
+          noise_multiplier=0.0,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(2,)),
+              tf.keras.layers.Dense(
+                  1, kernel_initializer='zeros', bias_initializer='zeros'
+              ),
+          ],
+      )
+      learning_rate = 0.01
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss = tf.keras.losses.MeanSquaredError()
+      model.compile(optimizer=optimizer, loss=loss)
+      model.fit(train_data, train_labels, epochs=1, batch_size=1)
 
-    model.compile(optimizer=optimizer, loss=loss)
-    model.fit(train_data, train_labels, epochs=1, batch_size=1)
+      model_weights = model.get_weights()
 
-    model_weights = model.get_weights()
+      unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
+      scale = min(1.0, l2_norm_clip / unclipped_gradient)
+      expected_weights = np.array([[90], [120]]) * scale * learning_rate
+      expected_bias = np.array([30]) * scale * learning_rate
 
-    unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
-    scale = min(1.0, l2_norm_clip / unclipped_gradient)
-    expected_weights = np.array([[90], [120]]) * scale * learning_rate
-    expected_bias = np.array([30]) * scale * learning_rate
-
-    # Check parameters are as expected, taking into account the learning rate.
-    self.assertAllClose(model_weights[0], expected_weights)
-    self.assertAllClose(model_weights[1], expected_bias)
+      # Check parameters are as expected, taking into account the learning rate.
+      self.assertAllClose(model_weights[0], expected_weights)
+      self.assertAllClose(model_weights[1], expected_bias)
 
   def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
                                   num_microbatches):
@@ -98,9 +107,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     if num_microbatches is None:
       num_microbatches = batch_size
 
-    preds = np.matmul(data, w)
+    preds = np.matmul(data, np.expand_dims(w, axis=1))
+
+    grads = 2 * data * (preds - labels)
 
-    grads = 2 * data * (labels - preds)[:, np.newaxis]
     grads = np.reshape(grads,
                        [num_microbatches, batch_size // num_microbatches, -1])
@@ -123,32 +133,45 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
   def testMicrobatches(self, l2_norm_clip, num_microbatches):
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
     w = np.zeros((2))
-    train_labels = np.array([1.0, 3.0, -2.0, -4.0])
+    train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
     learning_rate = 1.0
 
-    expected_grads = self._compute_expected_gradients(train_data, train_labels,
-                                                      w, l2_norm_clip,
-                                                      num_microbatches)
-    expected_weights = np.squeeze(learning_rate * expected_grads)
+    for test_reg, test_nm in zip(
+        get_layer_registries(), [num_microbatches, None]
+    ):
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss = tf.keras.losses.MeanSquaredError()
 
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss = tf.keras.losses.MeanSquaredError()
+      # Simple linear model returns w * x.
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=l2_norm_clip,
+          noise_multiplier=0.0,
+          num_microbatches=test_nm,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(2,)),
+              tf.keras.layers.Dense(
+                  1, use_bias=False, kernel_initializer='zeros'
+              ),
+          ],
+      )
+      model.compile(optimizer=optimizer, loss=loss)
+      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
 
-    # Simple linear model returns w * x + b.
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=l2_norm_clip,
-        noise_multiplier=0.0,
-        num_microbatches=num_microbatches,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(2,)),
-            tf.keras.layers.Dense(
-                1, use_bias=False, kernel_initializer='zeros')
-        ])
-    model.compile(optimizer=optimizer, loss=loss)
-    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+      model_weights = np.squeeze(model.get_weights())
 
-    model_weights = np.squeeze(model.get_weights())
-    self.assertAllClose(model_weights, expected_weights)
+      effective_num_microbatches = (
+          train_data.shape[0]
+          if model._num_microbatches is None
+          else num_microbatches
+      )
+
+      expected_grads = self._compute_expected_gradients(
+          train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
+      )
+      expected_weights = np.squeeze(-learning_rate * expected_grads)
+
+      self.assertAllClose(model_weights, expected_weights)
 
   @parameterized.named_parameters(
       ('noise_multiplier 3 2 1', 3.0, 2.0, 1),
       ('noise_multiplier 5 4 1', 5.0, 4.0, 1),
       ('noise_multiplier 3 2 4', 3.0, 2.0, 4),
       ('noise_multiplier 5 4 4', 5.0, 4.0, 4),
   )
   def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier,
                           num_microbatches):
@@ -168,59 +191,81 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     # Data is one example of length 1000, set to zero, with label zero.
     train_data = np.zeros((4, 1000))
-    train_labels = np.array([0.0, 0.0, 0.0, 0.0])
+    train_labels = np.array([[0.0], [0.0], [0.0], [0.0]])
     learning_rate = 1.0
 
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss = tf.keras.losses.MeanSquaredError()
-    # Simple linear model returns w * x + b.
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=l2_norm_clip,
-        noise_multiplier=noise_multiplier,
-        num_microbatches=num_microbatches,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(1000,)),
-            tf.keras.layers.Dense(
-                1, kernel_initializer='zeros', bias_initializer='zeros')
-        ])
-    model.compile(optimizer=optimizer, loss=loss)
-    model.fit(train_data, train_labels, epochs=1, batch_size=4)
+    for test_reg, test_nm in zip(
+        get_layer_registries(), [num_microbatches, None]
+    ):
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss = tf.keras.losses.MeanSquaredError()
 
-    model_weights = model.get_weights()
-    measured_std = np.std(model_weights[0])
-    expected_std = l2_norm_clip * noise_multiplier / num_microbatches
+      # Simple linear model returns w * x + b.
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=l2_norm_clip,
+          noise_multiplier=noise_multiplier,
+          num_microbatches=test_nm,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(1000,)),
+              tf.keras.layers.Dense(
+                  1, kernel_initializer='zeros', bias_initializer='zeros'
+              ),
+          ],
+      )
+      model.compile(optimizer=optimizer, loss=loss)
+      model.fit(train_data, train_labels, epochs=1, batch_size=4)
 
-    # Test standard deviation is close to l2_norm_clip * noise_multiplier.
-    self.assertNear(measured_std, expected_std, 0.1 * expected_std)
+      effective_num_microbatches = (
+          train_data.shape[0]
+          if model._num_microbatches is None
+          else num_microbatches
+      )
+
+      model_weights = model.get_weights()
+      measured_std = np.std(model_weights[0])
+      expected_std = (
+          l2_norm_clip * noise_multiplier / effective_num_microbatches
+      )
+
+      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
+      self.assertNear(measured_std, expected_std, 0.1 * expected_std)
 
   # Simple check to make sure dimensions are correct when output has
   # dimension > 1.
   @parameterized.named_parameters(
-      ('mb_test None 1', None, 1),
+      ('mb_test None 2', None, 2),
       ('mb_test 1 2', 1, 2),
       ('mb_test 2 2', 2, 2),
       ('mb_test 4 4', 4, 4),
   )
   def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
-    train_labels = np.array([0, 1, 1, 0])
+    train_labels = np.array([[0], [1], [1], [0]])
     learning_rate = 1.0
 
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    for test_reg, test_nm in zip(
+        get_layer_registries(), [num_microbatches, None]
+    ):
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=1.0e9,
-        noise_multiplier=0.0,
-        num_microbatches=num_microbatches,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(2,)),
-            tf.keras.layers.Dense(
-                output_dimension, use_bias=False, kernel_initializer='zeros')
-        ])
-    model.compile(optimizer=optimizer, loss=loss_fn)
-    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=1.0e9,
+          noise_multiplier=0.0,
+          num_microbatches=test_nm,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(2,)),
+              tf.keras.layers.Dense(
+                  output_dimension, use_bias=False, kernel_initializer='zeros'
+              ),
+              tf.keras.layers.Dense(1),
+          ],
+      )
+      model.compile(optimizer=optimizer, loss=loss_fn)
+      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
 
   # Checks that calls to earlier API using `use_xla` as a positional argument
   # raise an exception.
@@ -237,8 +282,11 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
         layers=[
             tf.keras.layers.InputLayer(input_shape=(2,)),
             tf.keras.layers.Dense(
-                2, use_bias=False, kernel_initializer='zeros')
-        ])
+                2, use_bias=False, kernel_initializer='zeros'
+            ),
+            tf.keras.layers.Dense(1),
+        ],
+    )
 
 
 if __name__ == '__main__':
   tf.test.main()
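An illustrative note on docstring condition (ii) and the `clip_grads.py` change above: with `Reduction.NONE`, a Keras loss returns one value per example, which is the per-example loss vector that `compute_gradient_norms` now sums and differentiates instead of rejecting. A minimal sketch (values hypothetical):

```python
import tensorflow as tf

# One scalar loss per example: MSE with Reduction.NONE averages over the
# output axis, leaving a vector of shape (batch_size,).
loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE
)
y_true = tf.constant([[1.0], [3.0], [-2.0]])
y_pred = tf.constant([[0.5], [2.5], [0.0]])
per_example_losses = loss_fn(y_true, y_pred)   # shape: (3,)
summed_loss = tf.reduce_sum(per_example_losses)  # scalar that gets differentiated
```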
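An illustrative NumPy sketch of the microbatch clipping arithmetic that `_compute_expected_gradients` models in the tests, assuming the standard DP-SGD recipe used by the training step: per-example gradients are averaged within each microbatch, each microbatch gradient is clipped to `l2_norm_clip`, and the clipped gradients are averaged. The array values are hypothetical.

```python
import numpy as np

l2_norm_clip = 1.0
num_microbatches = 2
# Hypothetical per-example gradients, shape (batch_size=4, num_params=2).
per_example_grads = np.array(
    [[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]
)

# Group examples into microbatches and average within each microbatch.
microbatched = per_example_grads.reshape(num_microbatches, -1, 2)
microbatch_grads = microbatched.mean(axis=1)

# Clip each microbatch gradient to the L2 norm bound, then average.
norms = np.linalg.norm(microbatch_grads, axis=1, keepdims=True)
scales = np.minimum(1.0, l2_norm_clip / norms)
clipped = microbatch_grads * scales
final_grad = clipped.mean(axis=0)  # gradient handed to the optimizer (pre-noise)
```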
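An illustrative check of the relationship asserted in `testNoiseMultiplier`: with all-zero data and labels the clipped gradients vanish, so after one SGD step with learning rate 1.0 each weight is pure Gaussian noise with standard deviation `l2_norm_clip * noise_multiplier / num_microbatches`. The values come from the first parameterized case; the simulation is a stand-alone sketch, not the test itself.

```python
import numpy as np

l2_norm_clip, noise_multiplier, num_microbatches = 3.0, 2.0, 1

# Noise is added to the summed clipped gradients with scale
# l2_norm_clip * noise_multiplier, then divided by num_microbatches.
noise = np.random.normal(
    scale=l2_norm_clip * noise_multiplier, size=(1000,)
) / num_microbatches

measured_std = np.std(noise)
expected_std = l2_norm_clip * noise_multiplier / num_microbatches
assert abs(measured_std - expected_std) < 0.1 * expected_std
```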